Compare commits

...

31 Commits

Author SHA1 Message Date
Tanguy
41649f0999 Version 1.1.0 (#904) 2023-06-06 11:05:49 +00:00
diegomrsantos
67102873ba Fail only if all addresses in PeerRecord are invalid (#898)
Fixes https://github.com/waku-org/nwaku/issues/1768
2023-05-31 08:59:50 +02:00
diegomrsantos
d40d324160 fix missing import (#897) 2023-05-26 10:36:39 +02:00
diegomrsantos
a677b06273 Try a direct connection only if there isn't one already (#891) 2023-05-25 15:48:22 +02:00
diegomrsantos
6050cdef7e Refinement of Hole Punching Service (#892) 2023-05-25 15:47:00 +02:00
Tanguy
fedfa8e817 Fix bumper CI (#894) 2023-05-22 17:37:42 +02:00
diegomrsantos
6887b43777 Improve utility tests (#893) 2023-05-22 11:07:22 +02:00
Tanguy
225accd11b Less warnings (#813)
Co-authored-by: Diego <diego@status.im>
2023-05-18 10:24:17 +02:00
diegomrsantos
7d6bc545e0 Handle dns addrs in HP service (#890) 2023-05-16 14:59:02 +02:00
Tanguy
a1eb53b181 Fix gossipsub dOut handling (#883) 2023-04-26 13:44:45 +02:00
Tanguy
db629dca25 Fix network protocol metrics typo (#874) 2023-04-26 09:52:06 +02:00
diegomrsantos
a5666789b0 Hole Punching (#806)
Co-authored-by: Tanguy <tanguy@status.im>
2023-04-18 12:50:21 +02:00
diegomrsantos
b7726bf68f Dcutr (#824)
Co-authored-by: Tanguy <tanguy@status.im>
2023-04-14 16:23:19 +02:00
diegomrsantos
0221affe98 Invalid MA is ignored (#881) 2023-04-14 12:05:32 +00:00
diegomrsantos
edbd35b16c Fix interop tests (#882) 2023-04-13 19:38:34 +02:00
diegomrsantos
80cca0ecac Does not allow an empty MA (#877) 2023-04-06 14:19:01 +00:00
diegomrsantos
0041ed4cf8 Transport hole punching (#873)
Co-authored-by: Tanguy <tanguy@status.im>
2023-04-06 15:23:35 +02:00
Tanguy
95e98e8c51 Fix traffic metrics (#879) 2023-04-03 12:37:23 +00:00
Tanguy
4aa615c44c GossipSub: TimedEntry & shortAgent fixes (#858) 2023-04-03 11:05:01 +02:00
Tanguy
6b61ce8c91 GossipSub: Better IWANT handling (#875) 2023-04-03 10:56:20 +02:00
Alvaro Revuelta
53b060f8f0 Add getters for conns and streams (#878) 2023-03-31 00:16:39 +02:00
diegomrsantos
af5299f26c Create an ObservedAddrManager and add an AddressMapper in AutonatService and AutoRelayService (#871)
Co-authored-by: Tanguy <tanguy@status.im>
2023-03-24 15:42:49 +00:00
Tanguy
bac754e2ad Various gossipsub fixes (#827) 2023-03-21 17:13:25 +01:00
Tanguy
8d5ea43e2b Upgrade flow refactoring (#807) 2023-03-08 12:30:19 +01:00
Jacek Sieka
e573238705 reexport public types (#872) 2023-03-06 15:36:10 +00:00
Tanguy
c1a3bd8fee Fix pubsub CI logs (#861) 2023-03-01 16:59:44 +01:00
diegomrsantos
ddeb7b3bd4 Handle when peers ask each other at the same time (#865) 2023-02-21 17:49:41 +01:00
Tanguy
382b992e00 Interop tests (#864) 2023-02-20 14:26:53 +01:00
Tanguy
408dcf12bd Fix backward compatibility of #822 (#862) 2023-02-15 17:18:29 +01:00
Ludovic Chenut
0012b639c8 Fix testrelay (#860) 2023-02-15 11:18:42 +01:00
Tanguy
f7f1e89669 TCP Transport: enable NO_DELAY for clients (#822) 2023-02-14 10:35:44 +01:00
133 changed files with 3281 additions and 1578 deletions


@@ -10,11 +10,12 @@ jobs:
bumpProjects:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
target: [
{ repo: status-im/nimbus-eth2, branch: unstable },
{ repo: status-im/nwaku, branch: master },
{ repo: status-im/nim-codex, branch: main }
{ repo: waku-org/nwaku, branch: master },
{ repo: codex-storage/nim-codex, branch: master }
]
steps:
- name: Clone repo

.github/workflows/interop.yml (new file)

@@ -0,0 +1,54 @@
name: Interoperability Testing
on:
pull_request:
push:
branches:
- unstable
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
jobs:
run-multidim-interop:
name: Run multidimensional interoperability tests
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
with:
repository: libp2p/test-plans
submodules: true
fetch-depth: 0
- name: Build image
run: >
cd multidim-interop/impl/nim/v1.0 &&
make commitSha=$GITHUB_SHA image_name=nim-libp2p-head
- name: Create ping-version.json
run: >
(cat << EOF
{
"id": "nim-libp2p-head",
"containerImageID": "nim-libp2p-head",
"transports": [
"tcp",
"ws"
],
"secureChannels": [
"noise"
],
"muxers": [
"mplex",
"yamux"
]
}
EOF
) > ${{ github.workspace }}/test_head.json
- uses: libp2p/test-plans/.github/actions/run-interop-ping-test@master
with:
test-filter: nim-libp2p-head
extra-versions: ${{ github.workspace }}/test_head.json
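Taken together, the new workflow clones libp2p/test-plans, builds a Docker image of the current commit through the Makefile invocation shown above (commitSha and image_name are its only inputs), and registers the image with the interop runner via test_head.json, whose fields advertise the transports (tcp, ws), secure channel (noise) and muxers (mplex, yamux) the image supports; the runner then exercises each combination against the other registered implementations.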

.pinned

@@ -1,16 +1,16 @@
bearssl;https://github.com/status-im/nim-bearssl@#a647994910904b0103a05db3a5ec1ecfc4d91a88
chronicles;https://github.com/status-im/nim-chronicles@#32ac8679680ea699f7dbc046e8e0131cac97d41a
chronos;https://github.com/status-im/nim-chronos@#75d030ff71264513fb9701c75a326cd36fcb4692
bearssl;https://github.com/status-im/nim-bearssl@#acf9645e328bdcab481cfda1c158e07ecd46bd7b
chronicles;https://github.com/status-im/nim-chronicles@#1e6350870855541b381d77d4659688bc0d2c4227
chronos;https://github.com/status-im/nim-chronos@#ab5a8c2e0f6941fe3debd61dff0293790079d1b0
dnsclient;https://github.com/ba0f3/dnsclient.nim@#fcd7443634b950eaea574e5eaa00a628ae029823
faststreams;https://github.com/status-im/nim-faststreams@#b42daf41d8eb4fbce40add6836bed838f8d85b6f
faststreams;https://github.com/status-im/nim-faststreams@#814f8927e1f356f39219f37f069b83066bcc893a
httputils;https://github.com/status-im/nim-http-utils@#a85bd52ae0a956983ca6b3267c72961d2ec0245f
json_serialization;https://github.com/status-im/nim-json-serialization@#a7d815ed92f200f490c95d3cfd722089cc923ce6
metrics;https://github.com/status-im/nim-metrics@#21e99a2e9d9f80e68bef65c80ef781613005fccb
nimcrypto;https://github.com/cheatfate/nimcrypto@#24e006df85927f64916e60511620583b11403178
metrics;https://github.com/status-im/nim-metrics@#abf3acc7f06cee9ee2c287d2f31413dc3df4c04e
nimcrypto;https://github.com/cheatfate/nimcrypto@#4014ef939b51e02053c2e16dd3481d47bc9267dd
secp256k1;https://github.com/status-im/nim-secp256k1@#fd173fdff863ce2e211cf64c9a03bc7539fe40b0
serialization;https://github.com/status-im/nim-serialization@#d77417cba6896c26287a68e6a95762e45a1b87e5
stew;https://github.com/status-im/nim-stew@#7184d2424dc3945657884646a72715d494917aad
serialization;https://github.com/status-im/nim-serialization@#5b7cea55efeb074daa8abd8146a03a34adb4521a
stew;https://github.com/status-im/nim-stew@#003fe9f0c83c2b0b2ccbd37087e6d1ccd30a3234
testutils;https://github.com/status-im/nim-testutils@#dfc4c1b39f9ded9baf6365014de2b4bfb4dafc34
unittest2;https://github.com/status-im/nim-unittest2@#da8398c45cafd5bd7772da1fc96e3924a18d3823
websock;https://github.com/status-im/nim-websock@#691f069b209d372b1240d5ae1f57fb7bbafeaba7
zlib;https://github.com/status-im/nim-zlib@#6a6670afba6b97b29b920340e2641978c05ab4d8
unittest2;https://github.com/status-im/nim-unittest2@#883c7a50ad3b82158e64d074c5578fe33ab3c452
websock;https://github.com/status-im/nim-websock@#fea05cde8b123b38d1a0a8524b77efbc84daa848
zlib;https://github.com/status-im/nim-zlib@#826e2fc013f55b4478802d4f2e39f187c50d520a
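Each line of .pinned locks one dependency to an exact upstream commit in the form name;url@#commit, so this hunk is a straight version bump across the eleven updated packages. A minimal sketch of reading one such entry, assuming only the format above (parsePinned is a hypothetical helper, not part of the repo):

import std/strutils

# Hypothetical helper: split "name;url@#commit" into its three parts.
proc parsePinned(line: string): tuple[name, url, commit: string] =
  let
    parts = line.split(';', maxsplit = 1)
    rest = parts[1].split("@#", maxsplit = 1)
  (parts[0], rest[0], rest[1])

doAssert parsePinned("stew;https://github.com/status-im/nim-stew@#003fe9f0c83c2b0b2ccbd37087e6d1ccd30a3234").commit.len == 40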


@@ -12,6 +12,7 @@ switch("warning", "LockLevel:off")
if (NimMajor, NimMinor) < (1, 6):
--styleCheck:hint
else:
switch("warningAsError", "UseBase:on")
--styleCheck:error
# Avoid some rare stack corruption while using exceptions with a SEH-enabled


@@ -48,12 +48,9 @@ else:
stream/connection,
transports/transport,
transports/tcptransport,
transports/wstransport,
protocols/secure/noise,
protocols/ping,
cid,
multihash,
multibase,
multicodec,
errors,
switch,


@@ -1,7 +1,7 @@
mode = ScriptMode.Verbose
packageName = "libp2p"
version = "1.0.0"
version = "1.1.0"
author = "Status Research & Development GmbH"
description = "LibP2P implementation"
license = "MIT"


@@ -230,7 +230,7 @@ proc build*(b: SwitchBuilder): Switch
identify = Identify.new(peerInfo, b.sendSignedPeerRecord)
connManager = ConnManager.new(b.maxConnsPerPeer, b.maxConnections, b.maxIn, b.maxOut)
ms = MultistreamSelect.new()
muxedUpgrade = MuxedUpgrade.new(identify, b.muxers, secureManagerInstances, connManager, ms)
muxedUpgrade = MuxedUpgrade.new(b.muxers, secureManagerInstances, connManager, ms)
let
transports = block:
@@ -247,14 +247,13 @@ proc build*(b: SwitchBuilder): Switch
let peerStore =
if isSome(b.peerStoreCapacity):
PeerStore.new(b.peerStoreCapacity.get())
PeerStore.new(identify, b.peerStoreCapacity.get())
else:
PeerStore.new()
PeerStore.new(identify)
let switch = newSwitch(
peerInfo = peerInfo,
transports = transports,
identity = identify,
secureManagers = secureManagerInstances,
connManager = connManager,
ms = ms,
@@ -262,6 +261,8 @@ proc build*(b: SwitchBuilder): Switch
peerStore = peerStore,
services = b.services)
switch.mount(identify)
if b.autonat:
let autonat = Autonat.new(switch)
switch.mount(autonat)
@@ -297,9 +298,10 @@ proc newStandardSwitch*(
peerStoreCapacity = 1000): Switch
{.raises: [Defect, LPError], public.} =
## Helper for common switch configurations.
{.push warning[Deprecated]:off.}
if SecureProtocol.Secio in secureManagers:
quit("Secio is deprecated!") # use of secio is unsafe
{.pop.}
let addrs = when addrs is MultiAddress: @[addrs] else: addrs
var b = SwitchBuilder


@@ -32,6 +32,7 @@ const
type
TooManyConnectionsError* = object of LPError
AlreadyExpectingConnectionError* = object of LPError
ConnEventKind* {.pure.} = enum
Connected, # A connection was made and securely upgraded - there may be
@@ -54,7 +55,6 @@ type
PeerEventKind* {.pure.} = enum
Left,
Identified,
Joined
PeerEvent* = object
@@ -67,19 +67,14 @@ type
PeerEventHandler* =
proc(peerId: PeerId, event: PeerEvent): Future[void] {.gcsafe, raises: [Defect].}
MuxerHolder = object
muxer: Muxer
handle: Future[void]
ConnManager* = ref object of RootObj
maxConnsPerPeer: int
inSema*: AsyncSemaphore
outSema*: AsyncSemaphore
conns: Table[PeerId, HashSet[Connection]]
muxed: Table[Connection, MuxerHolder]
muxed: Table[PeerId, seq[Muxer]]
connEvents: array[ConnEventKind, OrderedSet[ConnEventHandler]]
peerEvents: array[PeerEventKind, OrderedSet[PeerEventHandler]]
expectedConnections: Table[PeerId, Future[Connection]]
expectedConnectionsOverLimit*: Table[(PeerId, Direction), Future[Muxer]]
peerStore*: PeerStore
ConnectionSlot* = object
@@ -109,42 +104,30 @@ proc new*(C: type ConnManager,
outSema: outSema)
proc connCount*(c: ConnManager, peerId: PeerId): int =
c.conns.getOrDefault(peerId).len
c.muxed.getOrDefault(peerId).len
proc connectedPeers*(c: ConnManager, dir: Direction): seq[PeerId] =
var peers = newSeq[PeerId]()
for peerId, conns in c.conns:
if conns.anyIt(it.dir == dir):
for peerId, mux in c.muxed:
if mux.anyIt(it.connection.dir == dir):
peers.add(peerId)
return peers
proc getConnections*(c: ConnManager): Table[PeerId, seq[Muxer]] =
return c.muxed
proc addConnEventHandler*(c: ConnManager,
handler: ConnEventHandler,
kind: ConnEventKind) =
## Add connection event handler - handlers must not raise exceptions!
##
try:
if isNil(handler): return
c.connEvents[kind].incl(handler)
except Exception as exc:
# TODO: there is an Exception being raised
# somewhere in the depths of the std.
# Might be related to https://github.com/nim-lang/Nim/issues/17382
raiseAssert exc.msg
if isNil(handler): return
c.connEvents[kind].incl(handler)
proc removeConnEventHandler*(c: ConnManager,
handler: ConnEventHandler,
kind: ConnEventKind) =
try:
c.connEvents[kind].excl(handler)
except Exception as exc:
# TODO: there is an Exception being raised
# somewhere in the depths of the std.
# Might be related to https://github.com/nim-lang/Nim/issues/17382
raiseAssert exc.msg
proc triggerConnEvent*(c: ConnManager,
peerId: PeerId,
@@ -171,26 +154,12 @@ proc addPeerEventHandler*(c: ConnManager,
##
if isNil(handler): return
try:
c.peerEvents[kind].incl(handler)
except Exception as exc:
# TODO: there is an Exception being raised
# somewhere in the depths of the std.
# Might be related to https://github.com/nim-lang/Nim/issues/17382
raiseAssert exc.msg
c.peerEvents[kind].incl(handler)
proc removePeerEventHandler*(c: ConnManager,
handler: PeerEventHandler,
kind: PeerEventKind) =
try:
c.peerEvents[kind].excl(handler)
except Exception as exc:
# TODO: there is an Exception being raised
# somewhere in the depths of the std.
# Might be related to https://github.com/nim-lang/Nim/issues/17382
raiseAssert exc.msg
c.peerEvents[kind].excl(handler)
proc triggerPeerEvents*(c: ConnManager,
peerId: PeerId,
@@ -201,14 +170,6 @@ proc triggerPeerEvents*(c: ConnManager,
return
try:
let count = c.connCount(peerId)
if event.kind == PeerEventKind.Joined and count != 1:
trace "peer already joined", peer = peerId, event = $event
return
elif event.kind == PeerEventKind.Left and count != 0:
trace "peer still connected or already left", peer = peerId, event = $event
return
trace "triggering peer events", peer = peerId, event = $event
var peerEvents: seq[Future[void]]
@@ -221,31 +182,22 @@ proc triggerPeerEvents*(c: ConnManager,
except CatchableError as exc: # handlers should not raise!
warn "Exception in triggerPeerEvents", exc = exc.msg, peer = peerId
proc expectConnection*(c: ConnManager, p: PeerId): Future[Connection] {.async.} =
proc expectConnection*(c: ConnManager, p: PeerId, dir: Direction): Future[Muxer] {.async.} =
## Wait for a peer to connect to us. This will bypass the `MaxConnectionsPerPeer` limit
if p in c.expectedConnections:
raise LPError.newException("Already expecting a connection from that peer")
let key = (p, dir)
if key in c.expectedConnectionsOverLimit:
raise newException(AlreadyExpectingConnectionError, "Already expecting an incoming connection from that peer")
let future = newFuture[Connection]()
c.expectedConnections[p] = future
let future = newFuture[Muxer]()
c.expectedConnectionsOverLimit[key] = future
try:
return await future
finally:
c.expectedConnections.del(p)
proc contains*(c: ConnManager, conn: Connection): bool =
## checks if a connection is being tracked by the
## connection manager
##
if isNil(conn):
return
return conn in c.conns.getOrDefault(conn.peerId)
c.expectedConnectionsOverLimit.del(key)
proc contains*(c: ConnManager, peerId: PeerId): bool =
peerId in c.conns
peerId in c.muxed
proc contains*(c: ConnManager, muxer: Muxer): bool =
## checks if a muxer is being tracked by the connection
@@ -253,183 +205,134 @@ proc contains*(c: ConnManager, muxer: Muxer): bool =
##
if isNil(muxer):
return
return false
let conn = muxer.connection
if conn notin c:
return
return muxer in c.muxed.getOrDefault(conn.peerId)
if conn notin c.muxed:
return
proc closeMuxer(muxer: Muxer) {.async.} =
trace "Cleaning up muxer", m = muxer
return muxer == c.muxed.getOrDefault(conn).muxer
proc closeMuxerHolder(muxerHolder: MuxerHolder) {.async.} =
trace "Cleaning up muxer", m = muxerHolder.muxer
await muxerHolder.muxer.close()
if not(isNil(muxerHolder.handle)):
await muxer.close()
if not(isNil(muxer.handler)):
try:
await muxerHolder.handle # TODO noraises?
await muxer.handler # TODO noraises?
except CatchableError as exc:
trace "Exception in close muxer handler", exc = exc.msg
trace "Cleaned up muxer", m = muxerHolder.muxer
proc delConn(c: ConnManager, conn: Connection) =
let peerId = conn.peerId
c.conns.withValue(peerId, peerConns):
peerConns[].excl(conn)
if peerConns[].len == 0:
c.conns.del(peerId) # invalidates `peerConns`
libp2p_peers.set(c.conns.len.int64)
trace "Removed connection", conn
proc cleanupConn(c: ConnManager, conn: Connection) {.async.} =
## clean connection's resources such as muxers and streams
if isNil(conn):
trace "Wont cleanup a nil connection"
return
# Remove connection from all tables without async breaks
var muxer = some(MuxerHolder())
if not c.muxed.pop(conn, muxer.get()):
muxer = none(MuxerHolder)
delConn(c, conn)
trace "Cleaned up muxer", m = muxer
proc muxCleanup(c: ConnManager, mux: Muxer) {.async.} =
try:
if muxer.isSome:
await closeMuxerHolder(muxer.get())
finally:
await conn.close()
trace "Triggering disconnect events", mux
let peerId = mux.connection.peerId
trace "Connection cleaned up", conn
let muxers = c.muxed.getOrDefault(peerId).filterIt(it != mux)
if muxers.len > 0:
c.muxed[peerId] = muxers
else:
c.muxed.del(peerId)
libp2p_peers.set(c.muxed.len.int64)
await c.triggerPeerEvents(peerId, PeerEvent(kind: PeerEventKind.Left))
proc onConnUpgraded(c: ConnManager, conn: Connection) {.async.} =
try:
trace "Triggering connect events", conn
conn.upgrade()
if not(c.peerStore.isNil):
c.peerStore.cleanup(peerId)
let peerId = conn.peerId
await c.triggerPeerEvents(
peerId, PeerEvent(kind: PeerEventKind.Joined, initiator: conn.dir == Direction.Out))
await c.triggerConnEvent(
peerId, ConnEvent(kind: ConnEventKind.Connected, incoming: conn.dir == Direction.In))
except CatchableError as exc:
# This is top-level procedure which will work as separate task, so it
# do not need to propagate CancelledError and should handle other errors
warn "Unexpected exception in switch peer connection cleanup",
conn, msg = exc.msg
proc peerCleanup(c: ConnManager, conn: Connection) {.async.} =
try:
trace "Triggering disconnect events", conn
let peerId = conn.peerId
await c.triggerConnEvent(
peerId, ConnEvent(kind: ConnEventKind.Disconnected))
await c.triggerPeerEvents(peerId, PeerEvent(kind: PeerEventKind.Left))
if not(c.peerStore.isNil):
c.peerStore.cleanup(peerId)
except CatchableError as exc:
# This is top-level procedure which will work as separate task, so it
# do not need to propagate CancelledError and should handle other errors
warn "Unexpected exception peer cleanup handler",
conn, msg = exc.msg
mux, msg = exc.msg
proc onClose(c: ConnManager, conn: Connection) {.async.} =
proc onClose(c: ConnManager, mux: Muxer) {.async.} =
## connection close event handler
##
## triggers the connection's resource cleanup
##
try:
await conn.join()
trace "Connection closed, cleaning up", conn
await c.cleanupConn(conn)
except CancelledError:
# This is top-level procedure which will work as separate task, so it
# do not need to propagate CancelledError.
debug "Unexpected cancellation in connection manager's cleanup", conn
await mux.connection.join()
trace "Connection closed, cleaning up", mux
except CatchableError as exc:
debug "Unexpected exception in connection manager's cleanup",
errMsg = exc.msg, conn
errMsg = exc.msg, mux
finally:
trace "Triggering peerCleanup", conn
asyncSpawn c.peerCleanup(conn)
await c.muxCleanup(mux)
proc selectConn*(c: ConnManager,
proc selectMuxer*(c: ConnManager,
peerId: PeerId,
dir: Direction): Connection =
dir: Direction): Muxer =
## Select a connection for the provided peer and direction
##
let conns = toSeq(
c.conns.getOrDefault(peerId))
.filterIt( it.dir == dir )
c.muxed.getOrDefault(peerId))
.filterIt( it.connection.dir == dir )
if conns.len > 0:
return conns[0]
proc selectConn*(c: ConnManager, peerId: PeerId): Connection =
proc selectMuxer*(c: ConnManager, peerId: PeerId): Muxer =
## Select a connection for the provided peer, giving priority
## to outgoing connections
##
var conn = c.selectConn(peerId, Direction.Out)
if isNil(conn):
conn = c.selectConn(peerId, Direction.In)
if isNil(conn):
var mux = c.selectMuxer(peerId, Direction.Out)
if isNil(mux):
mux = c.selectMuxer(peerId, Direction.In)
if isNil(mux):
trace "connection not found", peerId
return mux
return conn
proc selectMuxer*(c: ConnManager, conn: Connection): Muxer =
## select the muxer for the provided connection
proc storeMuxer*(c: ConnManager,
muxer: Muxer)
{.raises: [Defect, CatchableError].} =
## store the connection and muxer
##
if isNil(conn):
return
if isNil(muxer):
raise newException(LPError, "muxer cannot be nil")
if conn in c.muxed:
return c.muxed.getOrDefault(conn).muxer
else:
debug "no muxer for connection", conn
if isNil(muxer.connection):
raise newException(LPError, "muxer's connection cannot be nil")
proc storeConn*(c: ConnManager, conn: Connection)
{.raises: [Defect, LPError].} =
## store a connection
##
if isNil(conn):
raise newException(LPError, "Connection cannot be nil")
if conn.closed or conn.atEof:
if muxer.connection.closed or muxer.connection.atEof:
raise newException(LPError, "Connection closed or EOF")
let peerId = conn.peerId
let
peerId = muxer.connection.peerId
dir = muxer.connection.dir
# we use getOrDefault in the if below instead of [] to avoid the KeyError
if peerId in c.expectedConnections and
not(c.expectedConnections.getOrDefault(peerId).finished):
c.expectedConnections.getOrDefault(peerId).complete(conn)
elif c.conns.getOrDefault(peerId).len > c.maxConnsPerPeer:
debug "Too many connections for peer",
conn, conns = c.conns.getOrDefault(peerId).len
if c.muxed.getOrDefault(peerId).len > c.maxConnsPerPeer:
let key = (peerId, dir)
let expectedConn = c.expectedConnectionsOverLimit.getOrDefault(key)
if expectedConn != nil and not expectedConn.finished:
expectedConn.complete(muxer)
else:
debug "Too many connections for peer",
conns = c.muxed.getOrDefault(peerId).len
raise newTooManyConnectionsError()
raise newTooManyConnectionsError()
c.conns.mgetOrPut(peerId, HashSet[Connection]()).incl(conn)
libp2p_peers.set(c.conns.len.int64)
assert muxer notin c.muxed.getOrDefault(peerId)
# Launch on close listener
# All the errors are handled inside `onClose()` procedure.
asyncSpawn c.onClose(conn)
let
newPeer = peerId notin c.muxed
assert newPeer or c.muxed[peerId].len > 0
c.muxed.mgetOrPut(peerId, newSeq[Muxer]()).add(muxer)
libp2p_peers.set(c.muxed.len.int64)
trace "Stored connection",
conn, direction = $conn.dir, connections = c.conns.len
asyncSpawn c.triggerConnEvent(
peerId, ConnEvent(kind: ConnEventKind.Connected, incoming: dir == Direction.In))
if newPeer:
asyncSpawn c.triggerPeerEvents(
peerId, PeerEvent(kind: PeerEventKind.Joined, initiator: dir == Direction.Out))
asyncSpawn c.onClose(muxer)
trace "Stored muxer",
muxer, direction = $muxer.connection.dir, peers = c.muxed.len
proc getIncomingSlot*(c: ConnManager): Future[ConnectionSlot] {.async.} =
await c.inSema.acquire()
@@ -472,39 +375,17 @@ proc trackConnection*(cs: ConnectionSlot, conn: Connection) =
asyncSpawn semaphoreMonitor()
proc storeMuxer*(c: ConnManager,
muxer: Muxer,
handle: Future[void] = nil)
{.raises: [Defect, CatchableError].} =
## store the connection and muxer
##
if isNil(muxer):
raise newException(CatchableError, "muxer cannot be nil")
if isNil(muxer.connection):
raise newException(CatchableError, "muxer's connection cannot be nil")
if muxer.connection notin c:
raise newException(CatchableError, "cant add muxer for untracked connection")
c.muxed[muxer.connection] = MuxerHolder(
muxer: muxer,
handle: handle)
trace "Stored muxer",
muxer, handle = not handle.isNil, connections = c.conns.len
asyncSpawn c.onConnUpgraded(muxer.connection)
proc trackMuxer*(cs: ConnectionSlot, mux: Muxer) =
if isNil(mux):
cs.release()
return
cs.trackConnection(mux.connection)
proc getStream*(c: ConnManager,
peerId: PeerId,
dir: Direction): Future[Connection] {.async, gcsafe.} =
## get a muxed stream for the provided peer
## with the given direction
muxer: Muxer): Future[Connection] {.async, gcsafe.} =
## get a muxed stream for the passed muxer
##
let muxer = c.selectMuxer(c.selectConn(peerId, dir))
if not(isNil(muxer)):
return await muxer.newStream()
@@ -513,40 +394,25 @@ proc getStream*(c: ConnManager,
## get a muxed stream for the passed peer from any connection
##
let muxer = c.selectMuxer(c.selectConn(peerId))
if not(isNil(muxer)):
return await muxer.newStream()
return await c.getStream(c.selectMuxer(peerId))
proc getStream*(c: ConnManager,
conn: Connection): Future[Connection] {.async, gcsafe.} =
## get a muxed stream for the passed connection
peerId: PeerId,
dir: Direction): Future[Connection] {.async, gcsafe.} =
## get a muxed stream for the passed peer from a connection with `dir`
##
let muxer = c.selectMuxer(conn)
if not(isNil(muxer)):
return await muxer.newStream()
return await c.getStream(c.selectMuxer(peerId, dir))
proc dropPeer*(c: ConnManager, peerId: PeerId) {.async.} =
## drop connections and cleanup resources for peer
##
trace "Dropping peer", peerId
let conns = c.conns.getOrDefault(peerId)
for conn in conns:
trace "Removing connection", conn
delConn(c, conn)
var muxers: seq[MuxerHolder]
for conn in conns:
if conn in c.muxed:
muxers.add c.muxed[conn]
c.muxed.del(conn)
let muxers = c.muxed.getOrDefault(peerId)
for muxer in muxers:
await closeMuxerHolder(muxer)
for conn in conns:
await conn.close()
trace "Dropped peer", peerId
await closeMuxer(muxer)
trace "Peer dropped", peerId
@@ -556,24 +422,18 @@ proc close*(c: ConnManager) {.async.} =
##
trace "Closing ConnManager"
let conns = c.conns
c.conns.clear()
let muxed = c.muxed
c.muxed.clear()
let expected = c.expectedConnections
c.expectedConnections.clear()
let expected = c.expectedConnectionsOverLimit
c.expectedConnectionsOverLimit.clear()
for _, fut in expected:
await fut.cancelAndWait()
for _, muxer in muxed:
await closeMuxerHolder(muxer)
for _, conns2 in conns:
for conn in conns2:
await conn.close()
for _, muxers in muxed:
for mux in muxers:
await closeMuxer(mux)
trace "Closed ConnManager"


@@ -22,7 +22,7 @@ else:
import bearssl/blockx
from stew/assign2 import assign
from stew/ranges/ptr_arith import baseAddr
from stew/ptrops import baseAddr
const
ChaChaPolyKeySize = 32


@@ -20,7 +20,7 @@ when (NimMajor, NimMinor) < (1, 4):
else:
{.push raises: [].}
import bearssl/[ec, rand, hash]
import bearssl/[ec, rand]
import stew/results
from stew/assign2 import assign
export results


@@ -25,6 +25,9 @@ import nimcrypto/utils as ncrutils
import minasn1
export minasn1.Asn1Error
import stew/[results, ctops]
import ../utility
export results
const
@@ -74,7 +77,7 @@ type
EcResult*[T] = Result[T, EcError]
const
EcSupportedCurvesCint* = {cint(Secp256r1), cint(Secp384r1), cint(Secp521r1)}
EcSupportedCurvesCint* = @[cint(Secp256r1), cint(Secp384r1), cint(Secp521r1)]
proc `-`(x: uint32): uint32 {.inline.} =
result = (0xFFFF_FFFF'u32 - x) + 1'u32
@@ -243,7 +246,7 @@ proc random*(
var res = new EcPrivateKey
if ecKeygen(addr rng.vtable, ecimp,
addr res.key, addr res.buffer[0],
cast[cint](kind)) == 0:
safeConvert[cint](kind)) == 0:
err(EcKeyGenError)
else:
ok(res)
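The recurring change in this file and the ones below swaps cast for safeConvert from ../utility. A plausible sketch of such a helper, assuming it simply rejects any conversion that could truncate or change sign; the real definition lives in libp2p/utility.nim and may differ (the enum conversions above suggest it also accepts ordinals):

template safeConvert[T: SomeInteger](value: SomeInteger): T =
  # Accept only widenings that can never lose information;
  # `cast` would reinterpret the bits with no such check.
  when (T is SomeUnsignedInt and value is SomeUnsignedInt and sizeof(T) >= sizeof(value)) or
       (T is SomeSignedInt and value is SomeSignedInt and sizeof(T) >= sizeof(value)) or
       (T is SomeSignedInt and value is SomeUnsignedInt and sizeof(T) > sizeof(value)):
    T(value)
  else:
    {.error: "safeConvert: conversion may lose information".}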
@@ -630,11 +633,11 @@ proc init*(key: var EcPrivateKey, data: openArray[byte]): Result[void, Asn1Error
return err(Asn1Error.Incorrect)
if oid == Asn1OidSecp256r1:
curve = cast[cint](Secp256r1)
curve = safeConvert[cint](Secp256r1)
elif oid == Asn1OidSecp384r1:
curve = cast[cint](Secp384r1)
curve = safeConvert[cint](Secp384r1)
elif oid == Asn1OidSecp521r1:
curve = cast[cint](Secp521r1)
curve = safeConvert[cint](Secp521r1)
else:
return err(Asn1Error.Incorrect)
@@ -684,11 +687,11 @@ proc init*(pubkey: var EcPublicKey, data: openArray[byte]): Result[void, Asn1Err
return err(Asn1Error.Incorrect)
if oid == Asn1OidSecp256r1:
curve = cast[cint](Secp256r1)
curve = safeConvert[cint](Secp256r1)
elif oid == Asn1OidSecp384r1:
curve = cast[cint](Secp384r1)
curve = safeConvert[cint](Secp384r1)
elif oid == Asn1OidSecp521r1:
curve = cast[cint](Secp521r1)
curve = safeConvert[cint](Secp521r1)
else:
return err(Asn1Error.Incorrect)
@@ -774,13 +777,13 @@ proc initRaw*(key: var EcPrivateKey, data: openArray[byte]): bool =
## Procedure returns ``true`` on success, ``false`` otherwise.
var curve: cint
if len(data) == SecKey256Length:
curve = cast[cint](Secp256r1)
curve = safeConvert[cint](Secp256r1)
result = true
elif len(data) == SecKey384Length:
curve = cast[cint](Secp384r1)
curve = safeConvert[cint](Secp384r1)
result = true
elif len(data) == SecKey521Length:
curve = cast[cint](Secp521r1)
curve = safeConvert[cint](Secp521r1)
result = true
if result:
result = false
@@ -805,13 +808,13 @@ proc initRaw*(pubkey: var EcPublicKey, data: openArray[byte]): bool =
if len(data) > 0:
if data[0] == 0x04'u8:
if len(data) == PubKey256Length:
curve = cast[cint](Secp256r1)
curve = safeConvert[cint](Secp256r1)
result = true
elif len(data) == PubKey384Length:
curve = cast[cint](Secp384r1)
curve = safeConvert[cint](Secp384r1)
result = true
elif len(data) == PubKey521Length:
curve = cast[cint](Secp521r1)
curve = safeConvert[cint](Secp521r1)
result = true
if result:
result = false


@@ -22,6 +22,9 @@ import nimcrypto/[hash, sha2]
# We use `ncrutils` for constant-time hexadecimal encoding/decoding procedures.
import nimcrypto/utils as ncrutils
import stew/[results, ctops]
import ../../utility
export results
# This workaround needed because of some bugs in Nim Static[T].
@@ -170,15 +173,15 @@ proc feCopy(h: var Fe, f: Fe) =
h[9] = f9
proc load_3(inp: openArray[byte]): uint64 =
result = cast[uint64](inp[0])
result = result or (cast[uint64](inp[1]) shl 8)
result = result or (cast[uint64](inp[2]) shl 16)
result = safeConvert[uint64](inp[0])
result = result or (safeConvert[uint64](inp[1]) shl 8)
result = result or (safeConvert[uint64](inp[2]) shl 16)
proc load_4(inp: openArray[byte]): uint64 =
result = cast[uint64](inp[0])
result = result or (cast[uint64](inp[1]) shl 8)
result = result or (cast[uint64](inp[2]) shl 16)
result = result or (cast[uint64](inp[3]) shl 24)
result = safeConvert[uint64](inp[0])
result = result or (safeConvert[uint64](inp[1]) shl 8)
result = result or (safeConvert[uint64](inp[2]) shl 16)
result = result or (safeConvert[uint64](inp[3]) shl 24)
proc feFromBytes(h: var Fe, s: openArray[byte]) =
var c0, c1, c2, c3, c4, c5, c6, c7, c8, c9: int64
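A worked check of the loaders above: load_3 over bytes [0x01, 0x02, 0x03] assembles them little-endian as 0x01 or (0x02 shl 8) or (0x03 shl 16) = 0x030201 = 197121, and load_4 adds one more shl 24 term. Every safeConvert[uint64] here is a pure widening from uint8, so replacing cast is behaviour-preserving.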
@@ -299,106 +302,106 @@ proc feMul(h: var Fe, f, g: Fe) =
var f5_2 = 2 * f5
var f7_2 = 2 * f7
var f9_2 = 2 * f9
var f0g0 = cast[int64](f0) * cast[int64](g0)
var f0g1 = cast[int64](f0) * cast[int64](g1)
var f0g2 = cast[int64](f0) * cast[int64](g2)
var f0g3 = cast[int64](f0) * cast[int64](g3)
var f0g4 = cast[int64](f0) * cast[int64](g4)
var f0g5 = cast[int64](f0) * cast[int64](g5)
var f0g6 = cast[int64](f0) * cast[int64](g6)
var f0g7 = cast[int64](f0) * cast[int64](g7)
var f0g8 = cast[int64](f0) * cast[int64](g8)
var f0g9 = cast[int64](f0) * cast[int64](g9)
var f1g0 = cast[int64](f1) * cast[int64](g0)
var f1g1_2 = cast[int64](f1_2) * cast[int64](g1)
var f1g2 = cast[int64](f1) * cast[int64](g2)
var f1g3_2 = cast[int64](f1_2) * cast[int64](g3)
var f1g4 = cast[int64](f1) * cast[int64](g4)
var f1g5_2 = cast[int64](f1_2) * cast[int64](g5)
var f1g6 = cast[int64](f1) * cast[int64](g6)
var f1g7_2 = cast[int64](f1_2) * cast[int64](g7)
var f1g8 = cast[int64](f1) * cast[int64](g8)
var f1g9_38 = cast[int64](f1_2) * cast[int64](g9_19)
var f2g0 = cast[int64](f2) * cast[int64](g0)
var f2g1 = cast[int64](f2) * cast[int64](g1)
var f2g2 = cast[int64](f2) * cast[int64](g2)
var f2g3 = cast[int64](f2) * cast[int64](g3)
var f2g4 = cast[int64](f2) * cast[int64](g4)
var f2g5 = cast[int64](f2) * cast[int64](g5)
var f2g6 = cast[int64](f2) * cast[int64](g6)
var f2g7 = cast[int64](f2) * cast[int64](g7)
var f2g8_19 = cast[int64](f2) * cast[int64](g8_19)
var f2g9_19 = cast[int64](f2) * cast[int64](g9_19)
var f3g0 = cast[int64](f3) * cast[int64](g0)
var f3g1_2 = cast[int64](f3_2) * cast[int64](g1)
var f3g2 = cast[int64](f3) * cast[int64](g2)
var f3g3_2 = cast[int64](f3_2) * cast[int64](g3)
var f3g4 = cast[int64](f3) * cast[int64](g4)
var f3g5_2 = cast[int64](f3_2) * cast[int64](g5)
var f3g6 = cast[int64](f3) * cast[int64](g6)
var f3g7_38 = cast[int64](f3_2) * cast[int64](g7_19)
var f3g8_19 = cast[int64](f3) * cast[int64](g8_19)
var f3g9_38 = cast[int64](f3_2) * cast[int64](g9_19)
var f4g0 = cast[int64](f4) * cast[int64](g0)
var f4g1 = cast[int64](f4) * cast[int64](g1)
var f4g2 = cast[int64](f4) * cast[int64](g2)
var f4g3 = cast[int64](f4) * cast[int64](g3)
var f4g4 = cast[int64](f4) * cast[int64](g4)
var f4g5 = cast[int64](f4) * cast[int64](g5)
var f4g6_19 = cast[int64](f4) * cast[int64](g6_19)
var f4g7_19 = cast[int64](f4) * cast[int64](g7_19)
var f4g8_19 = cast[int64](f4) * cast[int64](g8_19)
var f4g9_19 = cast[int64](f4) * cast[int64](g9_19)
var f5g0 = cast[int64](f5) * cast[int64](g0)
var f5g1_2 = cast[int64](f5_2) * cast[int64](g1)
var f5g2 = cast[int64](f5) * cast[int64](g2)
var f5g3_2 = cast[int64](f5_2) * cast[int64](g3)
var f5g4 = cast[int64](f5) * cast[int64](g4)
var f5g5_38 = cast[int64](f5_2) * cast[int64](g5_19)
var f5g6_19 = cast[int64](f5) * cast[int64](g6_19)
var f5g7_38 = cast[int64](f5_2) * cast[int64](g7_19)
var f5g8_19 = cast[int64](f5) * cast[int64](g8_19)
var f5g9_38 = cast[int64](f5_2) * cast[int64](g9_19)
var f6g0 = cast[int64](f6) * cast[int64](g0)
var f6g1 = cast[int64](f6) * cast[int64](g1)
var f6g2 = cast[int64](f6) * cast[int64](g2)
var f6g3 = cast[int64](f6) * cast[int64](g3)
var f6g4_19 = cast[int64](f6) * cast[int64](g4_19)
var f6g5_19 = cast[int64](f6) * cast[int64](g5_19)
var f6g6_19 = cast[int64](f6) * cast[int64](g6_19)
var f6g7_19 = cast[int64](f6) * cast[int64](g7_19)
var f6g8_19 = cast[int64](f6) * cast[int64](g8_19)
var f6g9_19 = cast[int64](f6) * cast[int64](g9_19)
var f7g0 = cast[int64](f7) * cast[int64](g0)
var f7g1_2 = cast[int64](f7_2) * cast[int64](g1)
var f7g2 = cast[int64](f7) * cast[int64](g2)
var f7g3_38 = cast[int64](f7_2) * cast[int64](g3_19)
var f7g4_19 = cast[int64](f7) * cast[int64](g4_19)
var f7g5_38 = cast[int64](f7_2) * cast[int64](g5_19)
var f7g6_19 = cast[int64](f7) * cast[int64](g6_19)
var f7g7_38 = cast[int64](f7_2) * cast[int64](g7_19)
var f7g8_19 = cast[int64](f7) * cast[int64](g8_19)
var f7g9_38 = cast[int64](f7_2) * cast[int64](g9_19)
var f8g0 = cast[int64](f8) * cast[int64](g0)
var f8g1 = cast[int64](f8) * cast[int64](g1)
var f8g2_19 = cast[int64](f8) * cast[int64](g2_19)
var f8g3_19 = cast[int64](f8) * cast[int64](g3_19)
var f8g4_19 = cast[int64](f8) * cast[int64](g4_19)
var f8g5_19 = cast[int64](f8) * cast[int64](g5_19)
var f8g6_19 = cast[int64](f8) * cast[int64](g6_19)
var f8g7_19 = cast[int64](f8) * cast[int64](g7_19)
var f8g8_19 = cast[int64](f8) * cast[int64](g8_19)
var f8g9_19 = cast[int64](f8) * cast[int64](g9_19)
var f9g0 = cast[int64](f9) * cast[int64](g0)
var f9g1_38 = cast[int64](f9_2) * cast[int64](g1_19)
var f9g2_19 = cast[int64](f9) * cast[int64](g2_19)
var f9g3_38 = cast[int64](f9_2) * cast[int64](g3_19)
var f9g4_19 = cast[int64](f9) * cast[int64](g4_19)
var f9g5_38 = cast[int64](f9_2) * cast[int64](g5_19)
var f9g6_19 = cast[int64](f9) * cast[int64](g6_19)
var f9g7_38 = cast[int64](f9_2) * cast[int64](g7_19)
var f9g8_19 = cast[int64](f9) * cast[int64](g8_19)
var f9g9_38 = cast[int64](f9_2) * cast[int64](g9_19)
var f0g0 = safeConvert[int64](f0) * safeConvert[int64](g0)
var f0g1 = safeConvert[int64](f0) * safeConvert[int64](g1)
var f0g2 = safeConvert[int64](f0) * safeConvert[int64](g2)
var f0g3 = safeConvert[int64](f0) * safeConvert[int64](g3)
var f0g4 = safeConvert[int64](f0) * safeConvert[int64](g4)
var f0g5 = safeConvert[int64](f0) * safeConvert[int64](g5)
var f0g6 = safeConvert[int64](f0) * safeConvert[int64](g6)
var f0g7 = safeConvert[int64](f0) * safeConvert[int64](g7)
var f0g8 = safeConvert[int64](f0) * safeConvert[int64](g8)
var f0g9 = safeConvert[int64](f0) * safeConvert[int64](g9)
var f1g0 = safeConvert[int64](f1) * safeConvert[int64](g0)
var f1g1_2 = safeConvert[int64](f1_2) * safeConvert[int64](g1)
var f1g2 = safeConvert[int64](f1) * safeConvert[int64](g2)
var f1g3_2 = safeConvert[int64](f1_2) * safeConvert[int64](g3)
var f1g4 = safeConvert[int64](f1) * safeConvert[int64](g4)
var f1g5_2 = safeConvert[int64](f1_2) * safeConvert[int64](g5)
var f1g6 = safeConvert[int64](f1) * safeConvert[int64](g6)
var f1g7_2 = safeConvert[int64](f1_2) * safeConvert[int64](g7)
var f1g8 = safeConvert[int64](f1) * safeConvert[int64](g8)
var f1g9_38 = safeConvert[int64](f1_2) * safeConvert[int64](g9_19)
var f2g0 = safeConvert[int64](f2) * safeConvert[int64](g0)
var f2g1 = safeConvert[int64](f2) * safeConvert[int64](g1)
var f2g2 = safeConvert[int64](f2) * safeConvert[int64](g2)
var f2g3 = safeConvert[int64](f2) * safeConvert[int64](g3)
var f2g4 = safeConvert[int64](f2) * safeConvert[int64](g4)
var f2g5 = safeConvert[int64](f2) * safeConvert[int64](g5)
var f2g6 = safeConvert[int64](f2) * safeConvert[int64](g6)
var f2g7 = safeConvert[int64](f2) * safeConvert[int64](g7)
var f2g8_19 = safeConvert[int64](f2) * safeConvert[int64](g8_19)
var f2g9_19 = safeConvert[int64](f2) * safeConvert[int64](g9_19)
var f3g0 = safeConvert[int64](f3) * safeConvert[int64](g0)
var f3g1_2 = safeConvert[int64](f3_2) * safeConvert[int64](g1)
var f3g2 = safeConvert[int64](f3) * safeConvert[int64](g2)
var f3g3_2 = safeConvert[int64](f3_2) * safeConvert[int64](g3)
var f3g4 = safeConvert[int64](f3) * safeConvert[int64](g4)
var f3g5_2 = safeConvert[int64](f3_2) * safeConvert[int64](g5)
var f3g6 = safeConvert[int64](f3) * safeConvert[int64](g6)
var f3g7_38 = safeConvert[int64](f3_2) * safeConvert[int64](g7_19)
var f3g8_19 = safeConvert[int64](f3) * safeConvert[int64](g8_19)
var f3g9_38 = safeConvert[int64](f3_2) * safeConvert[int64](g9_19)
var f4g0 = safeConvert[int64](f4) * safeConvert[int64](g0)
var f4g1 = safeConvert[int64](f4) * safeConvert[int64](g1)
var f4g2 = safeConvert[int64](f4) * safeConvert[int64](g2)
var f4g3 = safeConvert[int64](f4) * safeConvert[int64](g3)
var f4g4 = safeConvert[int64](f4) * safeConvert[int64](g4)
var f4g5 = safeConvert[int64](f4) * safeConvert[int64](g5)
var f4g6_19 = safeConvert[int64](f4) * safeConvert[int64](g6_19)
var f4g7_19 = safeConvert[int64](f4) * safeConvert[int64](g7_19)
var f4g8_19 = safeConvert[int64](f4) * safeConvert[int64](g8_19)
var f4g9_19 = safeConvert[int64](f4) * safeConvert[int64](g9_19)
var f5g0 = safeConvert[int64](f5) * safeConvert[int64](g0)
var f5g1_2 = safeConvert[int64](f5_2) * safeConvert[int64](g1)
var f5g2 = safeConvert[int64](f5) * safeConvert[int64](g2)
var f5g3_2 = safeConvert[int64](f5_2) * safeConvert[int64](g3)
var f5g4 = safeConvert[int64](f5) * safeConvert[int64](g4)
var f5g5_38 = safeConvert[int64](f5_2) * safeConvert[int64](g5_19)
var f5g6_19 = safeConvert[int64](f5) * safeConvert[int64](g6_19)
var f5g7_38 = safeConvert[int64](f5_2) * safeConvert[int64](g7_19)
var f5g8_19 = safeConvert[int64](f5) * safeConvert[int64](g8_19)
var f5g9_38 = safeConvert[int64](f5_2) * safeConvert[int64](g9_19)
var f6g0 = safeConvert[int64](f6) * safeConvert[int64](g0)
var f6g1 = safeConvert[int64](f6) * safeConvert[int64](g1)
var f6g2 = safeConvert[int64](f6) * safeConvert[int64](g2)
var f6g3 = safeConvert[int64](f6) * safeConvert[int64](g3)
var f6g4_19 = safeConvert[int64](f6) * safeConvert[int64](g4_19)
var f6g5_19 = safeConvert[int64](f6) * safeConvert[int64](g5_19)
var f6g6_19 = safeConvert[int64](f6) * safeConvert[int64](g6_19)
var f6g7_19 = safeConvert[int64](f6) * safeConvert[int64](g7_19)
var f6g8_19 = safeConvert[int64](f6) * safeConvert[int64](g8_19)
var f6g9_19 = safeConvert[int64](f6) * safeConvert[int64](g9_19)
var f7g0 = safeConvert[int64](f7) * safeConvert[int64](g0)
var f7g1_2 = safeConvert[int64](f7_2) * safeConvert[int64](g1)
var f7g2 = safeConvert[int64](f7) * safeConvert[int64](g2)
var f7g3_38 = safeConvert[int64](f7_2) * safeConvert[int64](g3_19)
var f7g4_19 = safeConvert[int64](f7) * safeConvert[int64](g4_19)
var f7g5_38 = safeConvert[int64](f7_2) * safeConvert[int64](g5_19)
var f7g6_19 = safeConvert[int64](f7) * safeConvert[int64](g6_19)
var f7g7_38 = safeConvert[int64](f7_2) * safeConvert[int64](g7_19)
var f7g8_19 = safeConvert[int64](f7) * safeConvert[int64](g8_19)
var f7g9_38 = safeConvert[int64](f7_2) * safeConvert[int64](g9_19)
var f8g0 = safeConvert[int64](f8) * safeConvert[int64](g0)
var f8g1 = safeConvert[int64](f8) * safeConvert[int64](g1)
var f8g2_19 = safeConvert[int64](f8) * safeConvert[int64](g2_19)
var f8g3_19 = safeConvert[int64](f8) * safeConvert[int64](g3_19)
var f8g4_19 = safeConvert[int64](f8) * safeConvert[int64](g4_19)
var f8g5_19 = safeConvert[int64](f8) * safeConvert[int64](g5_19)
var f8g6_19 = safeConvert[int64](f8) * safeConvert[int64](g6_19)
var f8g7_19 = safeConvert[int64](f8) * safeConvert[int64](g7_19)
var f8g8_19 = safeConvert[int64](f8) * safeConvert[int64](g8_19)
var f8g9_19 = safeConvert[int64](f8) * safeConvert[int64](g9_19)
var f9g0 = safeConvert[int64](f9) * safeConvert[int64](g0)
var f9g1_38 = safeConvert[int64](f9_2) * safeConvert[int64](g1_19)
var f9g2_19 = safeConvert[int64](f9) * safeConvert[int64](g2_19)
var f9g3_38 = safeConvert[int64](f9_2) * safeConvert[int64](g3_19)
var f9g4_19 = safeConvert[int64](f9) * safeConvert[int64](g4_19)
var f9g5_38 = safeConvert[int64](f9_2) * safeConvert[int64](g5_19)
var f9g6_19 = safeConvert[int64](f9) * safeConvert[int64](g6_19)
var f9g7_38 = safeConvert[int64](f9_2) * safeConvert[int64](g7_19)
var f9g8_19 = safeConvert[int64](f9) * safeConvert[int64](g8_19)
var f9g9_38 = safeConvert[int64](f9_2) * safeConvert[int64](g9_19)
var
c0, c1, c2, c3, c4, c5, c6, c7, c8, c9: int64
h0: int64 = f0g0 + f1g9_38 + f2g8_19 + f3g7_38 + f4g6_19 + f5g5_38 +
@@ -493,7 +496,7 @@ proc verify32(x: openArray[byte], y: openArray[byte]): int32 =
proc feIsNegative(f: Fe): int32 =
var s: array[32, byte]
feToBytes(s, f)
result = cast[int32](s[0] and 1'u8)
result = safeConvert[int32](s[0] and 1'u8)
proc feIsNonZero(f: Fe): int32 =
var s: array[32, byte]
@@ -516,61 +519,61 @@ proc feSq(h: var Fe, f: Fe) =
var f7_38: int32 = 38 * f7
var f8_19: int32 = 19 * f8
var f9_38: int32 = 38 * f9
var f0f0: int64 = f0 * cast[int64](f0)
var f0f1_2: int64 = f0_2 * cast[int64](f1)
var f0f2_2: int64 = f0_2 * cast[int64](f2)
var f0f3_2: int64 = f0_2 * cast[int64](f3)
var f0f4_2: int64 = f0_2 * cast[int64](f4)
var f0f5_2: int64 = f0_2 * cast[int64](f5)
var f0f6_2: int64 = f0_2 * cast[int64](f6)
var f0f7_2: int64 = f0_2 * cast[int64](f7)
var f0f8_2: int64 = f0_2 * cast[int64](f8)
var f0f9_2: int64 = f0_2 * cast[int64](f9)
var f1f1_2: int64 = f1_2 * cast[int64](f1)
var f1f2_2: int64 = f1_2 * cast[int64](f2)
var f1f3_4: int64 = f1_2 * cast[int64](f3_2)
var f1f4_2: int64 = f1_2 * cast[int64](f4)
var f1f5_4: int64 = f1_2 * cast[int64](f5_2)
var f1f6_2: int64 = f1_2 * cast[int64](f6)
var f1f7_4: int64 = f1_2 * cast[int64](f7_2)
var f1f8_2: int64 = f1_2 * cast[int64](f8)
var f1f9_76: int64 = f1_2 * cast[int64](f9_38)
var f2f2: int64 = f2 * cast[int64](f2)
var f2f3_2: int64 = f2_2 * cast[int64](f3)
var f2f4_2: int64 = f2_2 * cast[int64](f4)
var f2f5_2: int64 = f2_2 * cast[int64](f5)
var f2f6_2: int64 = f2_2 * cast[int64](f6)
var f2f7_2: int64 = f2_2 * cast[int64](f7)
var f2f8_38: int64 = f2_2 * cast[int64](f8_19)
var f2f9_38: int64 = f2 * cast[int64](f9_38)
var f3f3_2: int64 = f3_2 * cast[int64](f3)
var f3f4_2: int64 = f3_2 * cast[int64](f4)
var f3f5_4: int64 = f3_2 * cast[int64](f5_2)
var f3f6_2: int64 = f3_2 * cast[int64](f6)
var f3f7_76: int64 = f3_2 * cast[int64](f7_38)
var f3f8_38: int64 = f3_2 * cast[int64](f8_19)
var f3f9_76: int64 = f3_2 * cast[int64](f9_38)
var f4f4: int64 = f4 * cast[int64](f4)
var f4f5_2: int64 = f4_2 * cast[int64](f5)
var f4f6_38: int64 = f4_2 * cast[int64](f6_19)
var f4f7_38: int64 = f4 * cast[int64](f7_38)
var f4f8_38: int64 = f4_2 * cast[int64](f8_19)
var f4f9_38: int64 = f4 * cast[int64](f9_38)
var f5f5_38: int64 = f5 * cast[int64](f5_38)
var f5f6_38: int64 = f5_2 * cast[int64](f6_19)
var f5f7_76: int64 = f5_2 * cast[int64](f7_38)
var f5f8_38: int64 = f5_2 * cast[int64](f8_19)
var f5f9_76: int64 = f5_2 * cast[int64](f9_38)
var f6f6_19: int64 = f6 * cast[int64](f6_19)
var f6f7_38: int64 = f6 * cast[int64](f7_38)
var f6f8_38: int64 = f6_2 * cast[int64](f8_19)
var f6f9_38: int64 = f6 * cast[int64](f9_38)
var f7f7_38: int64 = f7 * cast[int64](f7_38)
var f7f8_38: int64 = f7_2 * cast[int64](f8_19)
var f7f9_76: int64 = f7_2 * cast[int64](f9_38)
var f8f8_19: int64 = f8 * cast[int64](f8_19)
var f8f9_38: int64 = f8 * cast[int64](f9_38)
var f9f9_38: int64 = f9 * cast[int64](f9_38)
var f0f0: int64 = f0 * safeConvert[int64](f0)
var f0f1_2: int64 = f0_2 * safeConvert[int64](f1)
var f0f2_2: int64 = f0_2 * safeConvert[int64](f2)
var f0f3_2: int64 = f0_2 * safeConvert[int64](f3)
var f0f4_2: int64 = f0_2 * safeConvert[int64](f4)
var f0f5_2: int64 = f0_2 * safeConvert[int64](f5)
var f0f6_2: int64 = f0_2 * safeConvert[int64](f6)
var f0f7_2: int64 = f0_2 * safeConvert[int64](f7)
var f0f8_2: int64 = f0_2 * safeConvert[int64](f8)
var f0f9_2: int64 = f0_2 * safeConvert[int64](f9)
var f1f1_2: int64 = f1_2 * safeConvert[int64](f1)
var f1f2_2: int64 = f1_2 * safeConvert[int64](f2)
var f1f3_4: int64 = f1_2 * safeConvert[int64](f3_2)
var f1f4_2: int64 = f1_2 * safeConvert[int64](f4)
var f1f5_4: int64 = f1_2 * safeConvert[int64](f5_2)
var f1f6_2: int64 = f1_2 * safeConvert[int64](f6)
var f1f7_4: int64 = f1_2 * safeConvert[int64](f7_2)
var f1f8_2: int64 = f1_2 * safeConvert[int64](f8)
var f1f9_76: int64 = f1_2 * safeConvert[int64](f9_38)
var f2f2: int64 = f2 * safeConvert[int64](f2)
var f2f3_2: int64 = f2_2 * safeConvert[int64](f3)
var f2f4_2: int64 = f2_2 * safeConvert[int64](f4)
var f2f5_2: int64 = f2_2 * safeConvert[int64](f5)
var f2f6_2: int64 = f2_2 * safeConvert[int64](f6)
var f2f7_2: int64 = f2_2 * safeConvert[int64](f7)
var f2f8_38: int64 = f2_2 * safeConvert[int64](f8_19)
var f2f9_38: int64 = f2 * safeConvert[int64](f9_38)
var f3f3_2: int64 = f3_2 * safeConvert[int64](f3)
var f3f4_2: int64 = f3_2 * safeConvert[int64](f4)
var f3f5_4: int64 = f3_2 * safeConvert[int64](f5_2)
var f3f6_2: int64 = f3_2 * safeConvert[int64](f6)
var f3f7_76: int64 = f3_2 * safeConvert[int64](f7_38)
var f3f8_38: int64 = f3_2 * safeConvert[int64](f8_19)
var f3f9_76: int64 = f3_2 * safeConvert[int64](f9_38)
var f4f4: int64 = f4 * safeConvert[int64](f4)
var f4f5_2: int64 = f4_2 * safeConvert[int64](f5)
var f4f6_38: int64 = f4_2 * safeConvert[int64](f6_19)
var f4f7_38: int64 = f4 * safeConvert[int64](f7_38)
var f4f8_38: int64 = f4_2 * safeConvert[int64](f8_19)
var f4f9_38: int64 = f4 * safeConvert[int64](f9_38)
var f5f5_38: int64 = f5 * safeConvert[int64](f5_38)
var f5f6_38: int64 = f5_2 * safeConvert[int64](f6_19)
var f5f7_76: int64 = f5_2 * safeConvert[int64](f7_38)
var f5f8_38: int64 = f5_2 * safeConvert[int64](f8_19)
var f5f9_76: int64 = f5_2 * safeConvert[int64](f9_38)
var f6f6_19: int64 = f6 * safeConvert[int64](f6_19)
var f6f7_38: int64 = f6 * safeConvert[int64](f7_38)
var f6f8_38: int64 = f6_2 * safeConvert[int64](f8_19)
var f6f9_38: int64 = f6 * safeConvert[int64](f9_38)
var f7f7_38: int64 = f7 * safeConvert[int64](f7_38)
var f7f8_38: int64 = f7_2 * safeConvert[int64](f8_19)
var f7f9_76: int64 = f7_2 * safeConvert[int64](f9_38)
var f8f8_19: int64 = f8 * safeConvert[int64](f8_19)
var f8f9_38: int64 = f8 * safeConvert[int64](f9_38)
var f9f9_38: int64 = f9 * safeConvert[int64](f9_38)
var h0: int64 = f0f0 + f1f9_76 + f2f8_38 + f3f7_76 + f4f6_38 + f5f5_38
var h1: int64 = f0f1_2 + f2f9_38 + f3f8_38 + f4f7_38 + f5f6_38
var h2: int64 = f0f2_2 + f1f1_2 + f3f9_76 + f4f8_38 + f5f7_76 + f6f6_19
@@ -623,61 +626,61 @@ proc feSq2(h: var Fe, f: Fe) =
var f7_38 = 38 * f7
var f8_19 = 19 * f8
var f9_38 = 38 * f9
var f0f0 = cast[int64](f0) * cast[int64](f0)
var f0f1_2 = cast[int64](f0_2) * cast[int64](f1)
var f0f2_2 = cast[int64](f0_2) * cast[int64](f2)
var f0f3_2 = cast[int64](f0_2) * cast[int64](f3)
var f0f4_2 = cast[int64](f0_2) * cast[int64](f4)
var f0f5_2 = cast[int64](f0_2) * cast[int64](f5)
var f0f6_2 = cast[int64](f0_2) * cast[int64](f6)
var f0f7_2 = cast[int64](f0_2) * cast[int64](f7)
var f0f8_2 = cast[int64](f0_2) * cast[int64](f8)
var f0f9_2 = cast[int64](f0_2) * cast[int64](f9)
var f1f1_2 = cast[int64](f1_2) * cast[int64](f1)
var f1f2_2 = cast[int64](f1_2) * cast[int64](f2)
var f1f3_4 = cast[int64](f1_2) * cast[int64](f3_2)
var f1f4_2 = cast[int64](f1_2) * cast[int64](f4)
var f1f5_4 = cast[int64](f1_2) * cast[int64](f5_2)
var f1f6_2 = cast[int64](f1_2) * cast[int64](f6)
var f1f7_4 = cast[int64](f1_2) * cast[int64](f7_2)
var f1f8_2 = cast[int64](f1_2) * cast[int64](f8)
var f1f9_76 = cast[int64](f1_2) * cast[int64](f9_38)
var f2f2 = cast[int64](f2) * cast[int64](f2)
var f2f3_2 = cast[int64](f2_2) * cast[int64](f3)
var f2f4_2 = cast[int64](f2_2) * cast[int64](f4)
var f2f5_2 = cast[int64](f2_2) * cast[int64](f5)
var f2f6_2 = cast[int64](f2_2) * cast[int64](f6)
var f2f7_2 = cast[int64](f2_2) * cast[int64](f7)
var f2f8_38 = cast[int64](f2_2) * cast[int64](f8_19)
var f2f9_38 = cast[int64](f2) * cast[int64](f9_38)
var f3f3_2 = cast[int64](f3_2) * cast[int64](f3)
var f3f4_2 = cast[int64](f3_2) * cast[int64](f4)
var f3f5_4 = cast[int64](f3_2) * cast[int64](f5_2)
var f3f6_2 = cast[int64](f3_2) * cast[int64](f6)
var f3f7_76 = cast[int64](f3_2) * cast[int64](f7_38)
var f3f8_38 = cast[int64](f3_2) * cast[int64](f8_19)
var f3f9_76 = cast[int64](f3_2) * cast[int64](f9_38)
var f4f4 = cast[int64](f4) * cast[int64](f4)
var f4f5_2 = cast[int64](f4_2) * cast[int64](f5)
var f4f6_38 = cast[int64](f4_2) * cast[int64](f6_19)
var f4f7_38 = cast[int64](f4) * cast[int64](f7_38)
var f4f8_38 = cast[int64](f4_2) * cast[int64](f8_19)
var f4f9_38 = cast[int64](f4) * cast[int64](f9_38)
var f5f5_38 = cast[int64](f5) * cast[int64](f5_38)
var f5f6_38 = cast[int64](f5_2) * cast[int64](f6_19)
var f5f7_76 = cast[int64](f5_2) * cast[int64](f7_38)
var f5f8_38 = cast[int64](f5_2) * cast[int64](f8_19)
var f5f9_76 = cast[int64](f5_2) * cast[int64](f9_38)
var f6f6_19 = cast[int64](f6) * cast[int64](f6_19)
var f6f7_38 = cast[int64](f6) * cast[int64](f7_38)
var f6f8_38 = cast[int64](f6_2) * cast[int64](f8_19)
var f6f9_38 = cast[int64](f6) * cast[int64](f9_38)
var f7f7_38 = cast[int64](f7) * cast[int64](f7_38)
var f7f8_38 = cast[int64](f7_2) * cast[int64](f8_19)
var f7f9_76 = cast[int64](f7_2) * cast[int64](f9_38)
var f8f8_19 = cast[int64](f8) * cast[int64](f8_19)
var f8f9_38 = cast[int64](f8) * cast[int64](f9_38)
var f9f9_38 = cast[int64](f9) * cast[int64](f9_38)
var f0f0 = safeConvert[int64](f0) * safeConvert[int64](f0)
var f0f1_2 = safeConvert[int64](f0_2) * safeConvert[int64](f1)
var f0f2_2 = safeConvert[int64](f0_2) * safeConvert[int64](f2)
var f0f3_2 = safeConvert[int64](f0_2) * safeConvert[int64](f3)
var f0f4_2 = safeConvert[int64](f0_2) * safeConvert[int64](f4)
var f0f5_2 = safeConvert[int64](f0_2) * safeConvert[int64](f5)
var f0f6_2 = safeConvert[int64](f0_2) * safeConvert[int64](f6)
var f0f7_2 = safeConvert[int64](f0_2) * safeConvert[int64](f7)
var f0f8_2 = safeConvert[int64](f0_2) * safeConvert[int64](f8)
var f0f9_2 = safeConvert[int64](f0_2) * safeConvert[int64](f9)
var f1f1_2 = safeConvert[int64](f1_2) * safeConvert[int64](f1)
var f1f2_2 = safeConvert[int64](f1_2) * safeConvert[int64](f2)
var f1f3_4 = safeConvert[int64](f1_2) * safeConvert[int64](f3_2)
var f1f4_2 = safeConvert[int64](f1_2) * safeConvert[int64](f4)
var f1f5_4 = safeConvert[int64](f1_2) * safeConvert[int64](f5_2)
var f1f6_2 = safeConvert[int64](f1_2) * safeConvert[int64](f6)
var f1f7_4 = safeConvert[int64](f1_2) * safeConvert[int64](f7_2)
var f1f8_2 = safeConvert[int64](f1_2) * safeConvert[int64](f8)
var f1f9_76 = safeConvert[int64](f1_2) * safeConvert[int64](f9_38)
var f2f2 = safeConvert[int64](f2) * safeConvert[int64](f2)
var f2f3_2 = safeConvert[int64](f2_2) * safeConvert[int64](f3)
var f2f4_2 = safeConvert[int64](f2_2) * safeConvert[int64](f4)
var f2f5_2 = safeConvert[int64](f2_2) * safeConvert[int64](f5)
var f2f6_2 = safeConvert[int64](f2_2) * safeConvert[int64](f6)
var f2f7_2 = safeConvert[int64](f2_2) * safeConvert[int64](f7)
var f2f8_38 = safeConvert[int64](f2_2) * safeConvert[int64](f8_19)
var f2f9_38 = safeConvert[int64](f2) * safeConvert[int64](f9_38)
var f3f3_2 = safeConvert[int64](f3_2) * safeConvert[int64](f3)
var f3f4_2 = safeConvert[int64](f3_2) * safeConvert[int64](f4)
var f3f5_4 = safeConvert[int64](f3_2) * safeConvert[int64](f5_2)
var f3f6_2 = safeConvert[int64](f3_2) * safeConvert[int64](f6)
var f3f7_76 = safeConvert[int64](f3_2) * safeConvert[int64](f7_38)
var f3f8_38 = safeConvert[int64](f3_2) * safeConvert[int64](f8_19)
var f3f9_76 = safeConvert[int64](f3_2) * safeConvert[int64](f9_38)
var f4f4 = safeConvert[int64](f4) * safeConvert[int64](f4)
var f4f5_2 = safeConvert[int64](f4_2) * safeConvert[int64](f5)
var f4f6_38 = safeConvert[int64](f4_2) * safeConvert[int64](f6_19)
var f4f7_38 = safeConvert[int64](f4) * safeConvert[int64](f7_38)
var f4f8_38 = safeConvert[int64](f4_2) * safeConvert[int64](f8_19)
var f4f9_38 = safeConvert[int64](f4) * safeConvert[int64](f9_38)
var f5f5_38 = safeConvert[int64](f5) * safeConvert[int64](f5_38)
var f5f6_38 = safeConvert[int64](f5_2) * safeConvert[int64](f6_19)
var f5f7_76 = safeConvert[int64](f5_2) * safeConvert[int64](f7_38)
var f5f8_38 = safeConvert[int64](f5_2) * safeConvert[int64](f8_19)
var f5f9_76 = safeConvert[int64](f5_2) * safeConvert[int64](f9_38)
var f6f6_19 = safeConvert[int64](f6) * safeConvert[int64](f6_19)
var f6f7_38 = safeConvert[int64](f6) * safeConvert[int64](f7_38)
var f6f8_38 = safeConvert[int64](f6_2) * safeConvert[int64](f8_19)
var f6f9_38 = safeConvert[int64](f6) * safeConvert[int64](f9_38)
var f7f7_38 = safeConvert[int64](f7) * safeConvert[int64](f7_38)
var f7f8_38 = safeConvert[int64](f7_2) * safeConvert[int64](f8_19)
var f7f9_76 = safeConvert[int64](f7_2) * safeConvert[int64](f9_38)
var f8f8_19 = safeConvert[int64](f8) * safeConvert[int64](f8_19)
var f8f9_38 = safeConvert[int64](f8) * safeConvert[int64](f9_38)
var f9f9_38 = safeConvert[int64](f9) * safeConvert[int64](f9_38)
var
c0, c1, c2, c3, c4, c5, c6, c7, c8, c9: int64
h0: int64 = f0f0 + f1f9_76 + f2f8_38 + f3f7_76 + f4f6_38 + f5f5_38
@@ -834,7 +837,7 @@ proc geFromBytesNegateVartime(h: var GeP3, s: openArray[byte]): int32 =
return -1;
feMul(h.x, h.x, SqrTm1)
if feIsNegative(h.x) == cast[int32](s[31] shr 7):
if feIsNegative(h.x) == safeConvert[int32](s[31] shr 7):
feNeg(h.x, h.x)
feMul(h.t, h.x, h.y)
@@ -956,14 +959,14 @@ proc equal(b, c: int8): byte =
var ub = cast[byte](b)
var uc = cast[byte](c)
var x = ub xor uc
var y = cast[uint32](x)
var y = safeConvert[uint32](x)
y = y - 1
y = y shr 31
result = cast[byte](y)
proc negative(b: int8): byte =
var x = cast[uint64](b)
x = x shr 63
var x = cast[uint8](b)
x = x shr 7
result = cast[byte](x)
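A worked check of the rewritten negative: for b = -1, cast[uint8](b) is 0xFF and 0xFF shr 7 is 1, while any non-negative b has the top bit clear and yields 0; the old code reached the same sign bit by sign-extending to 64 bits and shifting by 63. equal remains constant-time: x = ub xor uc is zero only when b == c, so (safeConvert[uint32](x) - 1) shr 31 is 1 exactly in that case.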
proc cmov(t: var GePrecomp, u: GePrecomp, b: byte) =


@@ -15,7 +15,7 @@ else:
{.push raises: [].}
import nimcrypto
import bearssl/[kdf, rand, hash]
import bearssl/[kdf, hash]
type HkdfResult*[len: static int] = array[len, byte]


@@ -18,6 +18,7 @@ import stew/[endians2, results, ctops]
export results
# We use `ncrutils` for constant-time hexadecimal encoding/decoding procedures.
import nimcrypto/utils as ncrutils
import ../utility
type
Asn1Error* {.pure.} = enum
@@ -119,7 +120,7 @@ template toOpenArray*(af: Asn1Field): untyped =
template isEmpty*(ab: Asn1Buffer): bool =
ab.offset >= len(ab.buffer)
template isEnough*(ab: Asn1Buffer, length: int): bool =
template isEnough*(ab: Asn1Buffer, length: int64): bool =
len(ab.buffer) >= ab.offset + length
proc len*[T: Asn1Buffer|Asn1Composite](abc: T): int {.inline.} =
@@ -344,32 +345,6 @@ proc asn1EncodeTag[T: SomeUnsignedInt](dest: var openArray[byte],
dest[k - 1] = dest[k - 1] and 0x7F'u8
res
proc asn1EncodeOid*(dest: var openArray[byte], value: openArray[int]): int =
## Encode array of integers ``value`` as ASN.1 DER `OBJECT IDENTIFIER` and
## return number of bytes (octets) used.
##
## If length of ``dest`` is less then number of required bytes to encode
## ``value``, then result of encoding will not be stored in ``dest``
## but number of bytes (octets) required will be returned.
var buffer: array[16, byte]
var res = 1
var oidlen = 1
for i in 2..<len(value):
oidlen += asn1EncodeTag(buffer, cast[uint64](value[i]))
res += asn1EncodeLength(buffer, uint64(oidlen))
res += oidlen
if len(dest) >= res:
let last = dest.high
var offset = 1
dest[0] = Asn1Tag.Oid.code()
offset += asn1EncodeLength(dest.toOpenArray(offset, last), uint64(oidlen))
dest[offset] = cast[byte](value[0] * 40 + value[1])
offset += 1
for i in 2..<len(value):
offset += asn1EncodeTag(dest.toOpenArray(offset, last),
cast[uint64](value[i]))
res
proc asn1EncodeOid*(dest: var openArray[byte], value: openArray[byte]): int =
## Encode array of bytes ``value`` as ASN.1 DER `OBJECT IDENTIFIER` and return
## number of bytes (octets) used.
@@ -443,26 +418,29 @@ proc asn1EncodeContextTag*(dest: var openArray[byte], value: openArray[byte],
copyMem(addr dest[1 + lenlen], unsafeAddr value[0], len(value))
res
proc getLength(ab: var Asn1Buffer): Asn1Result[uint64] =
proc getLength(ab: var Asn1Buffer): Asn1Result[int] =
## Decode length part of ASN.1 TLV triplet.
if not ab.isEmpty():
let b = ab.buffer[ab.offset]
if (b and 0x80'u8) == 0x00'u8:
let length = cast[uint64](b)
let length = safeConvert[int](b)
ab.offset += 1
return ok(length)
if b == 0x80'u8:
return err(Asn1Error.Indefinite)
if b == 0xFF'u8:
return err(Asn1Error.Incorrect)
let octets = cast[uint64](b and 0x7F'u8)
if octets > 8'u64:
let octets = safeConvert[int](b and 0x7F'u8)
if octets > 8:
return err(Asn1Error.Overflow)
if ab.isEnough(int(octets)):
var length: uint64 = 0
for i in 0..<int(octets):
length = (length shl 8) or cast[uint64](ab.buffer[ab.offset + i + 1])
ab.offset = ab.offset + int(octets) + 1
if ab.isEnough(octets):
var lengthU: uint64 = 0
for i in 0..<octets:
lengthU = (lengthU shl 8) or safeConvert[uint64](ab.buffer[ab.offset + i + 1])
if lengthU > uint64(int64.high):
return err(Asn1Error.Overflow)
let length = int(lengthU)
ab.offset = ab.offset + octets + 1
return ok(length)
else:
return err(Asn1Error.Incomplete)
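Worked examples of the two length forms handled above: the single byte 0x05 is the short form and decodes directly to length 5, while 0x82 0x01 0x00 is the long form, 0x82 announcing two following octets that assemble to 0x0100 = 256. Returning int with the explicit lengthU > uint64(int64.high) guard moves overflow checking into getLength itself instead of leaving it to each caller's int(length) cast.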
@@ -474,8 +452,8 @@ proc getTag(ab: var Asn1Buffer, tag: var int): Asn1Result[Asn1Class] =
if not ab.isEmpty():
let
b = ab.buffer[ab.offset]
c = int((b and 0xC0'u8) shr 6)
tag = int(b and 0x3F)
c = safeConvert[int]((b and 0xC0'u8) shr 6)
tag = safeConvert[int](b and 0x3F)
ab.offset += 1
if c >= 0 and c < 4:
ok(cast[Asn1Class](c))
@@ -489,7 +467,7 @@ proc read*(ab: var Asn1Buffer): Asn1Result[Asn1Field] =
var
field: Asn1Field
tag, ttag, offset: int
length, tlength: uint64
length, tlength: int
aclass: Asn1Class
inclass: bool
@@ -519,7 +497,7 @@ proc read*(ab: var Asn1Buffer): Asn1Result[Asn1Field] =
if length != 1:
return err(Asn1Error.Incorrect)
if not ab.isEnough(int(length)):
if not ab.isEnough(length):
return err(Asn1Error.Incomplete)
let b = ab.buffer[ab.offset]
@@ -527,7 +505,7 @@ proc read*(ab: var Asn1Buffer): Asn1Result[Asn1Field] =
return err(Asn1Error.Incorrect)
field = Asn1Field(kind: Asn1Tag.Boolean, klass: aclass,
index: ttag, offset: int(ab.offset),
index: ttag, offset: ab.offset,
length: 1, buffer: ab.buffer)
field.vbool = (b == 0xFF'u8)
ab.offset += 1
@@ -538,12 +516,12 @@ proc read*(ab: var Asn1Buffer): Asn1Result[Asn1Field] =
if length == 0:
return err(Asn1Error.Incorrect)
if not ab.isEnough(int(length)):
if not ab.isEnough(length):
return err(Asn1Error.Incomplete)
# Count number of leading zeroes
var zc = 0
while (zc < int(length)) and (ab.buffer[ab.offset + zc] == 0x00'u8):
while (zc < length) and (ab.buffer[ab.offset + zc] == 0x00'u8):
inc(zc)
if zc > 1:
@@ -552,45 +530,45 @@ proc read*(ab: var Asn1Buffer): Asn1Result[Asn1Field] =
if zc == 0:
# Negative or Positive integer
field = Asn1Field(kind: Asn1Tag.Integer, klass: aclass,
index: ttag, offset: int(ab.offset),
length: int(length), buffer: ab.buffer)
index: ttag, offset: ab.offset,
length: length, buffer: ab.buffer)
if (ab.buffer[ab.offset] and 0x80'u8) == 0x80'u8:
# Negative integer
if length <= 8:
# We need this transformation because our field.vint is uint64.
for i in 0 ..< 8:
if i < 8 - int(length):
if i < 8 - length:
field.vint = (field.vint shl 8) or 0xFF'u64
else:
let offset = ab.offset + i - (8 - int(length))
field.vint = (field.vint shl 8) or uint64(ab.buffer[offset])
let offset = ab.offset + i - (8 - length)
field.vint = (field.vint shl 8) or safeConvert[uint64](ab.buffer[offset])
else:
# Positive integer
if length <= 8:
for i in 0 ..< int(length):
for i in 0 ..< length:
field.vint = (field.vint shl 8) or
uint64(ab.buffer[ab.offset + i])
ab.offset += int(length)
safeConvert[uint64](ab.buffer[ab.offset + i])
ab.offset += length
return ok(field)
else:
if length == 1:
# Zero value integer
field = Asn1Field(kind: Asn1Tag.Integer, klass: aclass,
index: ttag, offset: int(ab.offset),
length: int(length), vint: 0'u64,
index: ttag, offset: ab.offset,
length: length, vint: 0'u64,
buffer: ab.buffer)
ab.offset += int(length)
ab.offset += length
return ok(field)
else:
# Positive integer with leading zero
field = Asn1Field(kind: Asn1Tag.Integer, klass: aclass,
index: ttag, offset: int(ab.offset) + 1,
length: int(length) - 1, buffer: ab.buffer)
index: ttag, offset: ab.offset + 1,
length: length - 1, buffer: ab.buffer)
if length <= 9:
for i in 1 ..< int(length):
for i in 1 ..< length:
field.vint = (field.vint shl 8) or
uint64(ab.buffer[ab.offset + i])
ab.offset += int(length)
safeConvert[uint64](ab.buffer[ab.offset + i])
ab.offset += length
return ok(field)
of Asn1Tag.BitString.code():
@@ -606,13 +584,13 @@ proc read*(ab: var Asn1Buffer): Asn1Result[Asn1Field] =
else:
# Zero-length BIT STRING.
field = Asn1Field(kind: Asn1Tag.BitString, klass: aclass,
index: ttag, offset: int(ab.offset + 1),
index: ttag, offset: ab.offset + 1,
length: 0, ubits: 0, buffer: ab.buffer)
ab.offset += int(length)
ab.offset += length
return ok(field)
else:
if not ab.isEnough(int(length)):
if not ab.isEnough(length):
return err(Asn1Error.Incomplete)
let unused = ab.buffer[ab.offset]
@@ -620,27 +598,27 @@ proc read*(ab: var Asn1Buffer): Asn1Result[Asn1Field] =
# The number of unused bits must not be greater than `7`.
return err(Asn1Error.Incorrect)
let mask = (1'u8 shl int(unused)) - 1'u8
if (ab.buffer[ab.offset + int(length) - 1] and mask) != 0x00'u8:
let mask = (1'u8 shl safeConvert[int](unused)) - 1'u8
if (ab.buffer[ab.offset + length - 1] and mask) != 0x00'u8:
## All unused bits should be set to `0`.
return err(Asn1Error.Incorrect)
field = Asn1Field(kind: Asn1Tag.BitString, klass: aclass,
index: ttag, offset: int(ab.offset + 1),
length: int(length - 1), ubits: int(unused),
index: ttag, offset: ab.offset + 1,
length: length - 1, ubits: safeConvert[int](unused),
buffer: ab.buffer)
ab.offset += int(length)
ab.offset += length
return ok(field)
of Asn1Tag.OctetString.code():
# OCTET STRING
if not ab.isEnough(int(length)):
if not ab.isEnough(length):
return err(Asn1Error.Incomplete)
field = Asn1Field(kind: Asn1Tag.OctetString, klass: aclass,
index: ttag, offset: int(ab.offset),
length: int(length), buffer: ab.buffer)
ab.offset += int(length)
index: ttag, offset: ab.offset,
length: length, buffer: ab.buffer)
ab.offset += length
return ok(field)
of Asn1Tag.Null.code():
@@ -649,30 +627,30 @@ proc read*(ab: var Asn1Buffer): Asn1Result[Asn1Field] =
return err(Asn1Error.Incorrect)
field = Asn1Field(kind: Asn1Tag.Null, klass: aclass, index: ttag,
offset: int(ab.offset), length: 0, buffer: ab.buffer)
ab.offset += int(length)
offset: ab.offset, length: 0, buffer: ab.buffer)
ab.offset += length
return ok(field)
of Asn1Tag.Oid.code():
# OID
if not ab.isEnough(int(length)):
if not ab.isEnough(length):
return err(Asn1Error.Incomplete)
field = Asn1Field(kind: Asn1Tag.Oid, klass: aclass,
index: ttag, offset: int(ab.offset),
length: int(length), buffer: ab.buffer)
ab.offset += int(length)
index: ttag, offset: ab.offset,
length: length, buffer: ab.buffer)
ab.offset += length
return ok(field)
of Asn1Tag.Sequence.code():
# SEQUENCE
if not ab.isEnough(int(length)):
if not ab.isEnough(length):
return err(Asn1Error.Incomplete)
field = Asn1Field(kind: Asn1Tag.Sequence, klass: aclass,
index: ttag, offset: int(ab.offset),
length: int(length), buffer: ab.buffer)
ab.offset += int(length)
index: ttag, offset: ab.offset,
length: length, buffer: ab.buffer)
ab.offset += length
return ok(field)
else:

View File

@@ -686,7 +686,7 @@ proc `==`*(a, b: RsaPrivateKey): bool =
false
else:
if a.seck.nBitlen == b.seck.nBitlen:
if cast[int](a.seck.nBitlen) > 0:
if a.seck.nBitlen > 0'u:
let r1 = CT.isEqual(getArray(a.buffer, a.seck.p, a.seck.plen),
getArray(b.buffer, b.seck.p, b.seck.plen))
let r2 = CT.isEqual(getArray(a.buffer, a.seck.q, a.seck.qlen),

View File

@@ -35,9 +35,6 @@ type
SkSignature* = distinct secp256k1.SkSignature
SkKeyPair* = distinct secp256k1.SkKeyPair
template pubkey*(v: SkKeyPair): SkPublicKey = SkPublicKey(secp256k1.SkKeyPair(v).pubkey)
template seckey*(v: SkKeyPair): SkPrivateKey = SkPrivateKey(secp256k1.SkKeyPair(v).seckey)
proc random*(t: typedesc[SkPrivateKey], rng: var HmacDrbgContext): SkPrivateKey =
#TODO is there a better way?
var rngPtr = addr rng

View File

@@ -17,7 +17,7 @@ import std/[os, osproc, strutils, tables, strtabs, sequtils]
import pkg/[chronos, chronicles]
import ../varint, ../multiaddress, ../multicodec, ../cid, ../peerid
import ../wire, ../multihash, ../protobuf/minprotobuf, ../errors
import ../crypto/crypto
import ../crypto/crypto, ../utility
export
peerid, multiaddress, multicodec, multihash, cid, crypto, wire, errors
@@ -170,7 +170,7 @@ proc requestIdentity(): ProtoBuffer =
## https://github.com/libp2p/go-libp2p-daemon/blob/master/conn.go
## Processing function `doIdentify(req *pb.Request)`.
result = initProtoBuffer({WithVarintLength})
result.write(1, cast[uint](RequestType.IDENTIFY))
result.write(1, safeConvert[uint](RequestType.IDENTIFY))
result.finish()
proc requestConnect(peerid: PeerId,
@@ -185,7 +185,7 @@ proc requestConnect(peerid: PeerId,
msg.write(2, item.data.buffer)
if timeout > 0:
msg.write(3, hint64(timeout))
result.write(1, cast[uint](RequestType.CONNECT))
result.write(1, safeConvert[uint](RequestType.CONNECT))
result.write(2, msg)
result.finish()
@@ -195,7 +195,7 @@ proc requestDisconnect(peerid: PeerId): ProtoBuffer =
result = initProtoBuffer({WithVarintLength})
var msg = initProtoBuffer()
msg.write(1, peerid)
result.write(1, cast[uint](RequestType.DISCONNECT))
result.write(1, safeConvert[uint](RequestType.DISCONNECT))
result.write(7, msg)
result.finish()
@@ -211,7 +211,7 @@ proc requestStreamOpen(peerid: PeerId,
msg.write(2, item)
if timeout > 0:
msg.write(3, hint64(timeout))
result.write(1, cast[uint](RequestType.STREAM_OPEN))
result.write(1, safeConvert[uint](RequestType.STREAM_OPEN))
result.write(3, msg)
result.finish()
@@ -224,7 +224,7 @@ proc requestStreamHandler(address: MultiAddress,
msg.write(1, address.data.buffer)
for item in protocols:
msg.write(2, item)
result.write(1, cast[uint](RequestType.STREAM_HANDLER))
result.write(1, safeConvert[uint](RequestType.STREAM_HANDLER))
result.write(4, msg)
result.finish()
@@ -232,13 +232,13 @@ proc requestListPeers(): ProtoBuffer =
## https://github.com/libp2p/go-libp2p-daemon/blob/master/conn.go
## Processing function `doListPeers(req *pb.Request)`
result = initProtoBuffer({WithVarintLength})
result.write(1, cast[uint](RequestType.LIST_PEERS))
result.write(1, safeConvert[uint](RequestType.LIST_PEERS))
result.finish()
proc requestDHTFindPeer(peer: PeerId, timeout = 0): ProtoBuffer =
## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go
## Processing function `doDHTFindPeer(req *pb.DHTRequest)`.
let msgid = cast[uint](DHTRequestType.FIND_PEER)
let msgid = safeConvert[uint](DHTRequestType.FIND_PEER)
result = initProtoBuffer({WithVarintLength})
var msg = initProtoBuffer()
msg.write(1, msgid)
@@ -246,7 +246,7 @@ proc requestDHTFindPeer(peer: PeerId, timeout = 0): ProtoBuffer =
if timeout > 0:
msg.write(7, hint64(timeout))
msg.finish()
result.write(1, cast[uint](RequestType.DHT))
result.write(1, safeConvert[uint](RequestType.DHT))
result.write(5, msg)
result.finish()
@@ -254,7 +254,7 @@ proc requestDHTFindPeersConnectedToPeer(peer: PeerId,
timeout = 0): ProtoBuffer =
## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go
## Processing function `doDHTFindPeersConnectedToPeer(req *pb.DHTRequest)`.
let msgid = cast[uint](DHTRequestType.FIND_PEERS_CONNECTED_TO_PEER)
let msgid = safeConvert[uint](DHTRequestType.FIND_PEERS_CONNECTED_TO_PEER)
result = initProtoBuffer({WithVarintLength})
var msg = initProtoBuffer()
msg.write(1, msgid)
@@ -262,7 +262,7 @@ proc requestDHTFindPeersConnectedToPeer(peer: PeerId,
if timeout > 0:
msg.write(7, hint64(timeout))
msg.finish()
result.write(1, cast[uint](RequestType.DHT))
result.write(1, safeConvert[uint](RequestType.DHT))
result.write(5, msg)
result.finish()
@@ -270,7 +270,7 @@ proc requestDHTFindProviders(cid: Cid,
count: uint32, timeout = 0): ProtoBuffer =
## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go
## Processing function `doDHTFindProviders(req *pb.DHTRequest)`.
let msgid = cast[uint](DHTRequestType.FIND_PROVIDERS)
let msgid = safeConvert[uint](DHTRequestType.FIND_PROVIDERS)
result = initProtoBuffer({WithVarintLength})
var msg = initProtoBuffer()
msg.write(1, msgid)
@@ -279,14 +279,14 @@ proc requestDHTFindProviders(cid: Cid,
if timeout > 0:
msg.write(7, hint64(timeout))
msg.finish()
result.write(1, cast[uint](RequestType.DHT))
result.write(1, safeConvert[uint](RequestType.DHT))
result.write(5, msg)
result.finish()
proc requestDHTGetClosestPeers(key: string, timeout = 0): ProtoBuffer =
## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go
## Processing function `doDHTGetClosestPeers(req *pb.DHTRequest)`.
let msgid = cast[uint](DHTRequestType.GET_CLOSEST_PEERS)
let msgid = safeConvert[uint](DHTRequestType.GET_CLOSEST_PEERS)
result = initProtoBuffer({WithVarintLength})
var msg = initProtoBuffer()
msg.write(1, msgid)
@@ -294,14 +294,14 @@ proc requestDHTGetClosestPeers(key: string, timeout = 0): ProtoBuffer =
if timeout > 0:
msg.write(7, hint64(timeout))
msg.finish()
result.write(1, cast[uint](RequestType.DHT))
result.write(1, safeConvert[uint](RequestType.DHT))
result.write(5, msg)
result.finish()
proc requestDHTGetPublicKey(peer: PeerId, timeout = 0): ProtoBuffer =
## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go
## Processing function `doDHTGetPublicKey(req *pb.DHTRequest)`.
let msgid = cast[uint](DHTRequestType.GET_PUBLIC_KEY)
let msgid = safeConvert[uint](DHTRequestType.GET_PUBLIC_KEY)
result = initProtoBuffer({WithVarintLength})
var msg = initProtoBuffer()
msg.write(1, msgid)
@@ -309,14 +309,14 @@ proc requestDHTGetPublicKey(peer: PeerId, timeout = 0): ProtoBuffer =
if timeout > 0:
msg.write(7, hint64(timeout))
msg.finish()
result.write(1, cast[uint](RequestType.DHT))
result.write(1, safeConvert[uint](RequestType.DHT))
result.write(5, msg)
result.finish()
proc requestDHTGetValue(key: string, timeout = 0): ProtoBuffer =
## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go
## Processing function `doDHTGetValue(req *pb.DHTRequest)`.
let msgid = cast[uint](DHTRequestType.GET_VALUE)
let msgid = safeConvert[uint](DHTRequestType.GET_VALUE)
result = initProtoBuffer({WithVarintLength})
var msg = initProtoBuffer()
msg.write(1, msgid)
@@ -324,14 +324,14 @@ proc requestDHTGetValue(key: string, timeout = 0): ProtoBuffer =
if timeout > 0:
msg.write(7, hint64(timeout))
msg.finish()
result.write(1, cast[uint](RequestType.DHT))
result.write(1, safeConvert[uint](RequestType.DHT))
result.write(5, msg)
result.finish()
proc requestDHTSearchValue(key: string, timeout = 0): ProtoBuffer =
## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go
## Processing function `doDHTSearchValue(req *pb.DHTRequest)`.
let msgid = cast[uint](DHTRequestType.SEARCH_VALUE)
let msgid = safeConvert[uint](DHTRequestType.SEARCH_VALUE)
result = initProtoBuffer({WithVarintLength})
var msg = initProtoBuffer()
msg.write(1, msgid)
@@ -339,7 +339,7 @@ proc requestDHTSearchValue(key: string, timeout = 0): ProtoBuffer =
if timeout > 0:
msg.write(7, hint64(timeout))
msg.finish()
result.write(1, cast[uint](RequestType.DHT))
result.write(1, safeConvert[uint](RequestType.DHT))
result.write(5, msg)
result.finish()
@@ -347,7 +347,7 @@ proc requestDHTPutValue(key: string, value: openArray[byte],
timeout = 0): ProtoBuffer =
## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go
## Processing function `doDHTPutValue(req *pb.DHTRequest)`.
let msgid = cast[uint](DHTRequestType.PUT_VALUE)
let msgid = safeConvert[uint](DHTRequestType.PUT_VALUE)
result = initProtoBuffer({WithVarintLength})
var msg = initProtoBuffer()
msg.write(1, msgid)
@@ -356,14 +356,14 @@ proc requestDHTPutValue(key: string, value: openArray[byte],
if timeout > 0:
msg.write(7, hint64(timeout))
msg.finish()
result.write(1, cast[uint](RequestType.DHT))
result.write(1, uint(RequestType.DHT))
result.write(5, msg)
result.finish()
proc requestDHTProvide(cid: Cid, timeout = 0): ProtoBuffer =
## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go
## Processing function `doDHTProvide(req *pb.DHTRequest)`.
let msgid = cast[uint](DHTRequestType.PROVIDE)
let msgid = safeConvert[uint](DHTRequestType.PROVIDE)
result = initProtoBuffer({WithVarintLength})
var msg = initProtoBuffer()
msg.write(1, msgid)
@@ -371,13 +371,13 @@ proc requestDHTProvide(cid: Cid, timeout = 0): ProtoBuffer =
if timeout > 0:
msg.write(7, hint64(timeout))
msg.finish()
result.write(1, cast[uint](RequestType.DHT))
result.write(1, safeConvert[uint](RequestType.DHT))
result.write(5, msg)
result.finish()
proc requestCMTagPeer(peer: PeerId, tag: string, weight: int): ProtoBuffer =
## https://github.com/libp2p/go-libp2p-daemon/blob/master/connmgr.go#L18
let msgid = cast[uint](ConnManagerRequestType.TAG_PEER)
let msgid = safeConvert[uint](ConnManagerRequestType.TAG_PEER)
result = initProtoBuffer({WithVarintLength})
var msg = initProtoBuffer()
msg.write(1, msgid)
@@ -385,83 +385,83 @@ proc requestCMTagPeer(peer: PeerId, tag: string, weight: int): ProtoBuffer =
msg.write(3, tag)
msg.write(4, hint64(weight))
msg.finish()
result.write(1, cast[uint](RequestType.CONNMANAGER))
result.write(1, safeConvert[uint](RequestType.CONNMANAGER))
result.write(6, msg)
result.finish()
proc requestCMUntagPeer(peer: PeerId, tag: string): ProtoBuffer =
## https://github.com/libp2p/go-libp2p-daemon/blob/master/connmgr.go#L33
let msgid = cast[uint](ConnManagerRequestType.UNTAG_PEER)
let msgid = safeConvert[uint](ConnManagerRequestType.UNTAG_PEER)
result = initProtoBuffer({WithVarintLength})
var msg = initProtoBuffer()
msg.write(1, msgid)
msg.write(2, peer)
msg.write(3, tag)
msg.finish()
result.write(1, cast[uint](RequestType.CONNMANAGER))
result.write(1, safeConvert[uint](RequestType.CONNMANAGER))
result.write(6, msg)
result.finish()
proc requestCMTrim(): ProtoBuffer =
## https://github.com/libp2p/go-libp2p-daemon/blob/master/connmgr.go#L47
let msgid = cast[uint](ConnManagerRequestType.TRIM)
let msgid = safeConvert[uint](ConnManagerRequestType.TRIM)
result = initProtoBuffer({WithVarintLength})
var msg = initProtoBuffer()
msg.write(1, msgid)
msg.finish()
result.write(1, cast[uint](RequestType.CONNMANAGER))
result.write(1, safeConvert[uint](RequestType.CONNMANAGER))
result.write(6, msg)
result.finish()
proc requestPSGetTopics(): ProtoBuffer =
## https://github.com/libp2p/go-libp2p-daemon/blob/master/pubsub.go
## Processing function `doPubsubGetTopics(req *pb.PSRequest)`.
let msgid = cast[uint](PSRequestType.GET_TOPICS)
let msgid = safeConvert[uint](PSRequestType.GET_TOPICS)
result = initProtoBuffer({WithVarintLength})
var msg = initProtoBuffer()
msg.write(1, msgid)
msg.finish()
result.write(1, cast[uint](RequestType.PUBSUB))
result.write(1, safeConvert[uint](RequestType.PUBSUB))
result.write(8, msg)
result.finish()
proc requestPSListPeers(topic: string): ProtoBuffer =
## https://github.com/libp2p/go-libp2p-daemon/blob/master/pubsub.go
## Processing function `doPubsubListPeers(req *pb.PSRequest)`.
let msgid = cast[uint](PSRequestType.LIST_PEERS)
let msgid = safeConvert[uint](PSRequestType.LIST_PEERS)
result = initProtoBuffer({WithVarintLength})
var msg = initProtoBuffer()
msg.write(1, msgid)
msg.write(2, topic)
msg.finish()
result.write(1, cast[uint](RequestType.PUBSUB))
result.write(1, safeConvert[uint](RequestType.PUBSUB))
result.write(8, msg)
result.finish()
proc requestPSPublish(topic: string, data: openArray[byte]): ProtoBuffer =
## https://github.com/libp2p/go-libp2p-daemon/blob/master/pubsub.go
## Processing function `doPubsubPublish(req *pb.PSRequest)`.
let msgid = cast[uint](PSRequestType.PUBLISH)
let msgid = safeConvert[uint](PSRequestType.PUBLISH)
result = initProtoBuffer({WithVarintLength})
var msg = initProtoBuffer()
msg.write(1, msgid)
msg.write(2, topic)
msg.write(3, data)
msg.finish()
result.write(1, cast[uint](RequestType.PUBSUB))
result.write(1, safeConvert[uint](RequestType.PUBSUB))
result.write(8, msg)
result.finish()
proc requestPSSubscribe(topic: string): ProtoBuffer =
## https://github.com/libp2p/go-libp2p-daemon/blob/master/pubsub.go
## Processing function `doPubsubSubscribe(req *pb.PSRequest)`.
let msgid = cast[uint](PSRequestType.SUBSCRIBE)
let msgid = safeConvert[uint](PSRequestType.SUBSCRIBE)
result = initProtoBuffer({WithVarintLength})
var msg = initProtoBuffer()
msg.write(1, msgid)
msg.write(2, topic)
msg.finish()
result.write(1, cast[uint](RequestType.PUBSUB))
result.write(1, safeConvert[uint](RequestType.PUBSUB))
result.write(8, msg)
result.finish()
@@ -515,7 +515,7 @@ proc socketExists(address: MultiAddress): Future[bool] {.async.} =
var transp = await connect(address)
await transp.closeWait()
result = true
except:
except CatchableError, Defect:
result = false
when defined(windows):
@@ -525,7 +525,7 @@ when defined(windows):
result = cast[int](getCurrentProcessId())
else:
proc getProcessId(): int =
result = cast[int](posix.getpid())
result = int(posix.getpid())
proc getSocket(pattern: string,
count: ptr int): Future[MultiAddress] {.async.} =
@@ -759,12 +759,8 @@ proc newDaemonApi*(flags: set[P2PDaemonFlags] = {},
# Starting daemon process
# echo "Starting ", cmd, " ", args.join(" ")
api.process =
try:
exceptionToAssert:
startProcess(cmd, "", args, env, {poParentStreams})
except CatchableError as exc:
raise exc
except Exception as exc:
raiseAssert exc.msg
# Waiting until daemon will not be bound to control socket.
while true:
if not api.process.running():
@@ -872,7 +868,7 @@ proc connect*(api: DaemonAPI, peer: PeerId,
timeout))
pb.withMessage() do:
discard
except:
except CatchableError, Defect:
await api.closeConnection(transp)
proc disconnect*(api: DaemonAPI, peer: PeerId) {.async.} =
@@ -1033,7 +1029,7 @@ proc enterDhtMessage(pb: ProtoBuffer, rt: DHTResponseType): ProtoBuffer
var dtype: uint
if pbDhtResponse.getRequiredField(1, dtype).isErr():
raise newException(DaemonLocalError, "Missing required DHT field `type`!")
if dtype != cast[uint](rt):
if dtype != safeConvert[uint](rt):
raise newException(DaemonLocalError, "Wrong DHT answer type! ")
var value: seq[byte]
@@ -1057,9 +1053,9 @@ proc getDhtMessageType(pb: ProtoBuffer): DHTResponseType
var dtype: uint
if pb.getRequiredField(1, dtype).isErr():
raise newException(DaemonLocalError, "Missing required DHT field `type`!")
if dtype == cast[uint](DHTResponseType.VALUE):
if dtype == safeConvert[uint](DHTResponseType.VALUE):
result = DHTResponseType.VALUE
elif dtype == cast[uint](DHTResponseType.END):
elif dtype == safeConvert[uint](DHTResponseType.END):
result = DHTResponseType.END
else:
raise newException(DaemonLocalError, "Wrong DHT answer type!")

View File

@@ -28,7 +28,8 @@ method connect*(
peerId: PeerId,
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true) {.async, base.} =
reuseConnection = true,
upgradeDir = Direction.Out) {.async, base.} =
## connect remote peer without negotiating
## a protocol
##

View File

@@ -7,7 +7,7 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
import std/[sugar, tables, sequtils]
import std/tables
import stew/results
import pkg/[chronos,
@@ -17,7 +17,9 @@ import pkg/[chronos,
import dial,
peerid,
peerinfo,
peerstore,
multicodec,
muxers/muxer,
multistream,
connmanager,
stream/connection,
@@ -34,25 +36,25 @@ logScope:
declareCounter(libp2p_total_dial_attempts, "total attempted dials")
declareCounter(libp2p_successful_dials, "dialed successful peers")
declareCounter(libp2p_failed_dials, "failed dials")
declareCounter(libp2p_failed_upgrades_outgoing, "outgoing connections failed upgrades")
type
DialFailedError* = object of LPError
Dialer* = ref object of Dial
localPeerId*: PeerId
ms: MultistreamSelect
connManager: ConnManager
dialLock: Table[PeerId, AsyncLock]
transports: seq[Transport]
peerStore: PeerStore
nameResolver: NameResolver
proc dialAndUpgrade(
self: Dialer,
peerId: Opt[PeerId],
hostname: string,
address: MultiAddress):
Future[Connection] {.async.} =
address: MultiAddress,
upgradeDir = Direction.Out):
Future[Muxer] {.async.} =
for transport in self.transports: # for each transport
if transport.handles(address): # check if it can dial it
@@ -69,29 +71,29 @@ proc dialAndUpgrade(
libp2p_failed_dials.inc()
return nil # Try the next address
# also keep track of the connection's bottom unsafe transport direction
# required by gossipsub scoring
dialed.transportDir = Direction.Out
libp2p_successful_dials.inc()
let conn =
let mux =
try:
await transport.upgradeOutgoing(dialed, peerId)
dialed.transportDir = upgradeDir
await transport.upgrade(dialed, upgradeDir, peerId)
except CatchableError as exc:
# If we failed to establish the connection through one transport,
# we won't succeed through another - no use in trying again
await dialed.close()
debug "Upgrade failed", msg = exc.msg, peerId
if exc isnot CancelledError:
libp2p_failed_upgrades_outgoing.inc()
if upgradeDir == Direction.Out:
libp2p_failed_upgrades_outgoing.inc()
else:
libp2p_failed_upgrades_incoming.inc()
# Try other address
return nil
doAssert not isNil(conn), "connection died after upgradeOutgoing"
debug "Dial successful", conn, peerId = conn.peerId
return conn
doAssert not isNil(mux), "connection died after upgrade " & $upgradeDir
debug "Dial successful", peerId = mux.connection.peerId
return mux
return nil
proc expandDnsAddr(
@@ -125,8 +127,9 @@ proc expandDnsAddr(
proc dialAndUpgrade(
self: Dialer,
peerId: Opt[PeerId],
addrs: seq[MultiAddress]):
Future[Connection] {.async.} =
addrs: seq[MultiAddress],
upgradeDir = Direction.Out):
Future[Muxer] {.async.} =
debug "Dialing peer", peerId
@@ -143,33 +146,26 @@ proc dialAndUpgrade(
else: await self.nameResolver.resolveMAddress(expandedAddress)
for resolvedAddress in resolvedAddresses:
result = await self.dialAndUpgrade(addrPeerId, hostname, resolvedAddress)
result = await self.dialAndUpgrade(addrPeerId, hostname, resolvedAddress, upgradeDir)
if not isNil(result):
return result
proc tryReusingConnection(self: Dialer, peerId: PeerId): Future[Opt[Connection]] {.async.} =
var conn = self.connManager.selectConn(peerId)
if conn == nil:
return Opt.none(Connection)
proc tryReusingConnection(self: Dialer, peerId: PeerId): Future[Opt[Muxer]] {.async.} =
let muxer = self.connManager.selectMuxer(peerId)
if muxer == nil:
return Opt.none(Muxer)
if conn.atEof or conn.closed:
# This connection should already have been removed from the connection
# manager - it's essentially a bug that we end up here - we'll fail
# for now, hoping that it will clean itself up later...
warn "dead connection in connection manager", conn
await conn.close()
raise newException(DialFailedError, "Zombie connection encountered")
trace "Reusing existing connection", conn, direction = $conn.dir
return Opt.some(conn)
trace "Reusing existing connection", muxer, direction = $muxer.connection.dir
return Opt.some(muxer)
proc internalConnect(
self: Dialer,
peerId: Opt[PeerId],
addrs: seq[MultiAddress],
forceDial: bool,
reuseConnection = true):
Future[Connection] {.async.} =
reuseConnection = true,
upgradeDir = Direction.Out):
Future[Muxer] {.async.} =
if Opt.some(self.localPeerId) == peerId:
raise newException(CatchableError, "can't dial self!")
@@ -179,32 +175,30 @@ proc internalConnect(
await lock.acquire()
if peerId.isSome and reuseConnection:
let connOpt = await self.tryReusingConnection(peerId.get())
if connOpt.isSome:
return connOpt.get()
let muxOpt = await self.tryReusingConnection(peerId.get())
if muxOpt.isSome:
return muxOpt.get()
let slot = self.connManager.getOutgoingSlot(forceDial)
let conn =
let muxed =
try:
await self.dialAndUpgrade(peerId, addrs)
await self.dialAndUpgrade(peerId, addrs, upgradeDir)
except CatchableError as exc:
slot.release()
raise exc
slot.trackConnection(conn)
if isNil(conn): # None of the addresses connected
slot.trackMuxer(muxed)
if isNil(muxed): # None of the addresses connected
raise newException(DialFailedError, "Unable to establish outgoing link")
# A disconnect could have happened right after
# we've added the connection so we check again
# to prevent races due to that.
if conn.closed() or conn.atEof():
# This can happen when the other end drops us
# before we get a chance to return the connection
# back to the dialer.
trace "Connection dead on arrival", conn
raise newLPStreamClosedError()
try:
self.connManager.storeMuxer(muxed)
await self.peerStore.identify(muxed)
except CatchableError as exc:
trace "Failed to finish outgoung upgrade", err=exc.msg
await muxed.close()
raise exc
return conn
return muxed
finally:
if lock.locked():
lock.release()
@@ -214,7 +208,8 @@ method connect*(
peerId: PeerId,
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true) {.async.} =
reuseConnection = true,
upgradeDir = Direction.Out) {.async.} =
## connect remote peer without negotiating
## a protocol
##
@@ -222,7 +217,7 @@ method connect*(
if self.connManager.connCount(peerId) > 0 and reuseConnection:
return
discard await self.internalConnect(Opt.some(peerId), addrs, forceDial, reuseConnection)
discard await self.internalConnect(Opt.some(peerId), addrs, forceDial, reuseConnection, upgradeDir)
method connect*(
self: Dialer,
@@ -235,21 +230,21 @@ method connect*(
return (await self.internalConnect(
Opt.some(fullAddress.get()[0]),
@[fullAddress.get()[1]],
false)).peerId
false)).connection.peerId
else:
if allowUnknownPeerId == false:
raise newException(DialFailedError, "Address without PeerID and unknown peer id disabled!")
return (await self.internalConnect(
Opt.none(PeerId),
@[address],
false)).peerId
false)).connection.peerId
proc negotiateStream(
self: Dialer,
conn: Connection,
protos: seq[string]): Future[Connection] {.async.} =
trace "Negotiating stream", conn, protos
let selected = await self.ms.select(conn, protos)
let selected = await MultistreamSelect.select(conn, protos)
if not protos.contains(selected):
await conn.closeWithEOF()
raise newException(DialFailedError, "Unable to select sub-protocol " & $protos)
@@ -267,11 +262,11 @@ method tryDial*(
trace "Check if it can dial", peerId, addrs
try:
let conn = await self.dialAndUpgrade(Opt.some(peerId), addrs)
if conn.isNil():
let mux = await self.dialAndUpgrade(Opt.some(peerId), addrs)
if mux.isNil():
raise newException(DialFailedError, "No valid multiaddress")
await conn.close()
return conn.observedAddr
await mux.close()
return mux.connection.observedAddr
except CancelledError as exc:
raise exc
except CatchableError as exc:
@@ -303,7 +298,7 @@ method dial*(
##
var
conn: Connection
conn: Muxer
stream: Connection
proc cleanup() {.async.} =
@@ -340,12 +335,12 @@ proc new*(
T: type Dialer,
localPeerId: PeerId,
connManager: ConnManager,
peerStore: PeerStore,
transports: seq[Transport],
ms: MultistreamSelect,
nameResolver: NameResolver = nil): Dialer =
T(localPeerId: localPeerId,
connManager: connManager,
transports: transports,
ms: ms,
peerStore: peerStore,
nameResolver: nameResolver)
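After this refactor the dial path yields a Muxer rather than a raw Connection: internalConnect stores the muxer, runs identify through the peer store, and connection-level data is reached via mux.connection. An illustrative calling pattern (a sketch of the flow only; internalConnect is module-private):

let mux = await dialer.internalConnect(Opt.some(peerId), addrs, forceDial = false)
echo "connected to ", mux.connection.peerId
let stream = await mux.newStream()   # open a stream on the muxed connection
await stream.close()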

View File

@@ -12,7 +12,6 @@ when (NimMajor, NimMinor) < (1, 4):
else:
{.push raises: [].}
import sequtils
import chronos
import ./discoverymngr,
../protocols/rendezvous,

View File

@@ -15,7 +15,7 @@ else:
{.push raises: [].}
{.push public.}
import pkg/chronos
import pkg/chronos, chronicles
import std/[nativesockets, hashes]
import tables, strutils, sets, stew/shims/net
import multicodec, multihash, multibase, transcoder, vbuffer, peerid,
@@ -23,6 +23,9 @@ import multicodec, multihash, multibase, transcoder, vbuffer, peerid,
import stew/[base58, base32, endians2, results]
export results, minprotobuf, vbuffer, utility
logScope:
topics = "libp2p multiaddress"
type
MAKind* = enum
None, Fixed, Length, Path, Marker
@@ -34,7 +37,7 @@ type
coder*: Transcoder
MultiAddress* = object
data*: VBuffer
data: VBuffer
MaPatternOp* = enum
Eq, Or, And
@@ -63,6 +66,10 @@ const
IPPROTO_TCP = Protocol.IPPROTO_TCP
IPPROTO_UDP = Protocol.IPPROTO_UDP
proc data*(ma: MultiAddress): VBuffer =
## Returns the data buffer of the MultiAddress.
return ma.data
proc hash*(a: MultiAddress): Hash =
var h: Hash = 0
h = h !& hash(a.data.buffer)
@@ -76,7 +83,7 @@ proc ip4StB(s: string, vb: var VBuffer): bool =
if a.family == IpAddressFamily.IPv4:
vb.writeArray(a.address_v4)
result = true
except:
except CatchableError:
discard
proc ip4BtS(vb: var VBuffer, s: var string): bool =
@@ -99,7 +106,7 @@ proc ip6StB(s: string, vb: var VBuffer): bool =
if a.family == IpAddressFamily.IPv6:
vb.writeArray(a.address_v6)
result = true
except:
except CatchableError:
discard
proc ip6BtS(vb: var VBuffer, s: var string): bool =
@@ -143,14 +150,14 @@ proc portStB(s: string, vb: var VBuffer): bool =
port[1] = cast[byte](nport and 0xFF)
vb.writeArray(port)
result = true
except:
except CatchableError:
discard
proc portBtS(vb: var VBuffer, s: var string): bool =
## Port number bufferToString() implementation.
var port: array[2, byte]
if vb.readArray(port) == 2:
var nport = (cast[uint16](port[0]) shl 8) or cast[uint16](port[1])
var nport = (safeConvert[uint16](port[0]) shl 8) or safeConvert[uint16](port[1])
s = $nport
result = true
@@ -168,7 +175,7 @@ proc p2pStB(s: string, vb: var VBuffer): bool =
if MultiHash.decode(data, mh).isOk:
vb.writeSeq(data)
result = true
except:
except CatchableError:
discard
proc p2pBtS(vb: var VBuffer, s: var string): bool =
@@ -203,14 +210,14 @@ proc onionStB(s: string, vb: var VBuffer): bool =
address[11] = cast[byte](nport and 0xFF)
vb.writeArray(address)
result = true
except:
except CatchableError:
discard
proc onionBtS(vb: var VBuffer, s: var string): bool =
## ONION address bufferToString() implementation.
var buf: array[12, byte]
if vb.readArray(buf) == 12:
var nport = (cast[uint16](buf[10]) shl 8) or cast[uint16](buf[11])
var nport = (safeConvert[uint16](buf[10]) shl 8) or safeConvert[uint16](buf[11])
s = Base32Lower.encode(buf.toOpenArray(0, 9))
s.add(":")
s.add($nport)
@@ -237,14 +244,14 @@ proc onion3StB(s: string, vb: var VBuffer): bool =
address[36] = cast[byte](nport and 0xFF)
vb.writeArray(address)
result = true
except:
except CatchableError:
discard
proc onion3BtS(vb: var VBuffer, s: var string): bool =
## ONION address bufferToString() implementation.
var buf: array[37, byte]
if vb.readArray(buf) == 37:
var nport = (cast[uint16](buf[35]) shl 8) or cast[uint16](buf[36])
var nport = (safeConvert[uint16](buf[35]) shl 8) or safeConvert[uint16](buf[36])
s = Base32Lower.encode(buf.toOpenArray(0, 34))
s.add(":")
s.add($nport)
@@ -463,13 +470,23 @@ const
DNS* = mapOr(DNSANY, DNS4, DNS6, DNSADDR)
IP* = mapOr(IP4, IP6)
DNS_OR_IP* = mapOr(DNS, IP)
TCP* = mapOr(mapAnd(DNS, mapEq("tcp")), mapAnd(IP, mapEq("tcp")))
UDP* = mapOr(mapAnd(DNS, mapEq("udp")), mapAnd(IP, mapEq("udp")))
TCP_DNS* = mapAnd(DNS, mapEq("tcp"))
TCP_IP* =mapAnd(IP, mapEq("tcp"))
TCP* = mapOr(TCP_DNS, TCP_IP)
UDP_DNS* = mapAnd(DNS, mapEq("udp"))
UDP_IP* = mapAnd(IP, mapEq("udp"))
UDP* = mapOr(UDP_DNS, UDP_IP)
UTP* = mapAnd(UDP, mapEq("utp"))
QUIC* = mapAnd(UDP, mapEq("quic"))
UNIX* = mapEq("unix")
WS_DNS* = mapAnd(TCP_DNS, mapEq("ws"))
WS_IP* = mapAnd(TCP_IP, mapEq("ws"))
WS* = mapAnd(TCP, mapEq("ws"))
WSS_DNS* = mapAnd(TCP_DNS, mapEq("wss"))
WSS_IP* = mapAnd(TCP_IP, mapEq("wss"))
WSS* = mapAnd(TCP, mapEq("wss"))
WebSockets_DNS* = mapOr(WS_DNS, WSS_DNS)
WebSockets_IP* = mapOr(WS_IP, WSS_IP)
WebSockets* = mapOr(WS, WSS)
Onion3* = mapEq("onion3")
TcpOnion3* = mapAnd(TCP, Onion3)
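The new split constants compose with mapAnd/mapOr like the existing patterns and let callers distinguish DNS-based from IP-based transports. A quick check, assuming the module's exact-match semantics for MaPattern.match:

let wsAddr = MultiAddress.init("/ip4/127.0.0.1/tcp/4040/ws").tryGet()
doAssert WS_IP.match(wsAddr)         # ip4 + tcp + ws
doAssert not TCP_IP.match(wsAddr)    # exact match fails on the trailing /ws
doAssert WebSockets.match(wsAddr)    # ws or wss over tcp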
@@ -873,6 +890,8 @@ proc getProtocol(name: string): MAProtocol {.inline.} =
proc init*(mtype: typedesc[MultiAddress],
value: string): MaResult[MultiAddress] =
## Initialize MultiAddress object from string representation ``value``.
if len(value) == 0 or value == "/":
return err("multiaddress: Address must not be empty!")
var parts = value.trimRight('/').split('/')
if len(parts[0]) != 0:
err("multiaddress: Invalid MultiAddress, must start with `/`")
@@ -920,7 +939,7 @@ proc init*(mtype: typedesc[MultiAddress],
data: openArray[byte]): MaResult[MultiAddress] =
## Initialize MultiAddress with array of bytes ``data``.
if len(data) == 0:
err("multiaddress: Address could not be empty!")
err("multiaddress: Address must not be empty!")
else:
var res: MultiAddress
res.data = initVBuffer()
@@ -1111,6 +1130,10 @@ proc getField*(pb: ProtoBuffer, field: int,
proc getRepeatedField*(pb: ProtoBuffer, field: int,
value: var seq[MultiAddress]): ProtoResult[bool] {.
inline.} =
## Reads a repeated field from a protobuf message. ``field`` is the field number. If the message is malformed, an error is returned.
## If the field is not present in the message, ``ok(false)`` is returned and ``value`` is empty. If the field is present
## but no items could be parsed, ``err(ProtoError.IncorrectBlob)`` is returned and ``value`` is empty.
## If the field is present and at least one item could be parsed, ``ok(true)`` is returned and ``value`` contains the parsed items.
var items: seq[seq[byte]]
value.setLen(0)
let res = ? pb.getRepeatedField(field, items)
@@ -1122,6 +1145,8 @@ proc getRepeatedField*(pb: ProtoBuffer, field: int,
if ma.isOk():
value.add(ma.get())
else:
value.setLen(0)
return err(ProtoError.IncorrectBlob)
ok(true)
debug "Not supported MultiAddress in blob", ma = item
if value.len == 0:
err(ProtoError.IncorrectBlob)
else:
ok(true)
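With this change a repeated MultiAddress field decodes successfully as long as at least one element is valid: invalid elements are logged and skipped, and err(ProtoError.IncorrectBlob) is returned only when every element fails. A decoding sketch, with pb assumed to be a ProtoBuffer carrying repeated addresses in field 2:

var addrs: seq[MultiAddress]
let res = pb.getRepeatedField(2, addrs)
if res.isOk() and res.get():
  echo "decoded ", addrs.len, " valid address(es)"   # bad entries were skipped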

View File

@@ -19,8 +19,6 @@ import varint, vbuffer
import stew/results
export results
{.deadCodeElim: on.}
## The list of officially supported codecs can be found here:
## https://github.com/multiformats/multicodec/blob/master/table.csv
const MultiCodecList = [

View File

@@ -21,12 +21,11 @@ logScope:
topics = "libp2p multistream"
const
MsgSize* = 1024
Codec* = "/multistream/1.0.0"
MsgSize = 1024
Codec = "/multistream/1.0.0"
MSCodec* = "\x13" & Codec & "\n"
Na* = "\x03na\n"
Ls* = "\x03ls\n"
Na = "na\n"
Ls = "ls\n"
type
Matcher* = proc (proto: string): bool {.gcsafe, raises: [Defect].}
@@ -45,7 +44,7 @@ type
proc new*(T: typedesc[MultistreamSelect]): T =
T(
codec: MSCodec,
codec: Codec,
)
template validateSuffix(str: string): untyped =
@@ -54,13 +53,13 @@ template validateSuffix(str: string): untyped =
else:
raise newException(MultiStreamError, "MultistreamSelect failed, malformed message")
proc select*(m: MultistreamSelect,
proc select*(_: MultistreamSelect | type MultistreamSelect,
conn: Connection,
proto: seq[string]):
Future[string] {.async.} =
trace "initiating handshake", conn, codec = m.codec
trace "initiating handshake", conn, codec = Codec
## select a remote protocol
await conn.write(m.codec) # write handshake
await conn.writeLp(Codec & "\n") # write handshake
if proto.len() > 0:
trace "selecting proto", conn, proto = proto[0]
await conn.writeLp((proto[0] & "\n")) # select proto
@@ -102,13 +101,13 @@ proc select*(m: MultistreamSelect,
# No alternatives, fail
return ""
proc select*(m: MultistreamSelect,
proc select*(_: MultistreamSelect | type MultistreamSelect,
conn: Connection,
proto: string): Future[bool] {.async.} =
if proto.len > 0:
return (await m.select(conn, @[proto])) == proto
return (await MultistreamSelect.select(conn, @[proto])) == proto
else:
return (await m.select(conn, @[])) == Codec
return (await MultistreamSelect.select(conn, @[])) == Codec
proc select*(m: MultistreamSelect, conn: Connection): Future[bool] =
m.select(conn, "")
@@ -119,7 +118,7 @@ proc list*(m: MultistreamSelect,
if not await m.select(conn):
return
await conn.write(Ls) # send ls
await conn.writeLp(Ls) # send ls
var list = newSeq[string]()
let ms = string.fromBytes(await conn.readLp(MsgSize))
@@ -129,68 +128,86 @@ proc list*(m: MultistreamSelect,
result = list
proc handle*(
_: type MultistreamSelect,
conn: Connection,
protos: seq[string],
matchers = newSeq[Matcher](),
active: bool = false,
): Future[string] {.async, gcsafe.} =
trace "Starting multistream negotiation", conn, handshaked = active
var handshaked = active
while not conn.atEof:
var ms = string.fromBytes(await conn.readLp(MsgSize))
validateSuffix(ms)
if not handshaked and ms != Codec:
debug "expected handshake message", conn, instead=ms
raise newException(CatchableError,
"MultistreamSelect handling failed, invalid first message")
trace "handle: got request", conn, ms
if ms.len() <= 0:
trace "handle: invalid proto", conn
await conn.writeLp(Na)
case ms:
of "ls":
trace "handle: listing protos", conn
#TODO this doesn't seem to follow the spec, each protocol
# should be length prefixed. Not very important
# since LS is getting deprecated
await conn.writeLp(protos.join("\n") & "\n")
of Codec:
if not handshaked:
await conn.writeLp(Codec & "\n")
handshaked = true
else:
trace "handle: sending `na` for duplicate handshake while handshaked",
conn
await conn.writeLp(Na)
elif ms in protos or matchers.anyIt(it(ms)):
trace "found handler", conn, protocol = ms
await conn.writeLp(ms & "\n")
conn.protocol = ms
return ms
else:
trace "no handlers", conn, protocol = ms
await conn.writeLp(Na)
proc handle*(m: MultistreamSelect, conn: Connection, active: bool = false) {.async, gcsafe.} =
trace "Starting multistream handler", conn, handshaked = active
var handshaked = active
var
handshaked = active
protos: seq[string]
matchers: seq[Matcher]
for h in m.handlers:
if not isNil(h.match):
matchers.add(h.match)
for proto in h.protos:
protos.add(proto)
try:
while not conn.atEof:
var ms = string.fromBytes(await conn.readLp(MsgSize))
validateSuffix(ms)
let ms = await MultistreamSelect.handle(conn, protos, matchers, active)
for h in m.handlers:
if (not isNil(h.match) and h.match(ms)) or h.protos.contains(ms):
trace "found handler", conn, protocol = ms
if not handshaked and ms != Codec:
notice "expected handshake message", conn, instead=ms
raise newException(CatchableError,
"MultistreamSelect handling failed, invalid first message")
trace "handle: got request", conn, ms
if ms.len() <= 0:
trace "handle: invalid proto", conn
await conn.write(Na)
if m.handlers.len() == 0:
trace "handle: sending `na` for protocol", conn, protocol = ms
await conn.write(Na)
continue
case ms:
of "ls":
trace "handle: listing protos", conn
var protos = ""
for h in m.handlers:
for proto in h.protos:
protos &= (proto & "\n")
await conn.writeLp(protos)
of Codec:
if not handshaked:
await conn.write(m.codec)
handshaked = true
else:
trace "handle: sending `na` for duplicate handshake while handshaked",
conn
await conn.write(Na)
else:
for h in m.handlers:
if (not isNil(h.match) and h.match(ms)) or h.protos.contains(ms):
trace "found handler", conn, protocol = ms
var protocolHolder = h
let maxIncomingStreams = protocolHolder.protocol.maxIncomingStreams
if protocolHolder.openedStreams.getOrDefault(conn.peerId) >= maxIncomingStreams:
debug "Max streams for protocol reached, blocking new stream",
conn, protocol = ms, maxIncomingStreams
return
protocolHolder.openedStreams.inc(conn.peerId)
try:
await conn.writeLp(ms & "\n")
conn.protocol = ms
await protocolHolder.protocol.handler(conn, ms)
finally:
protocolHolder.openedStreams.inc(conn.peerId, -1)
if protocolHolder.openedStreams[conn.peerId] == 0:
protocolHolder.openedStreams.del(conn.peerId)
return
debug "no handlers", conn, protocol = ms
await conn.write(Na)
var protocolHolder = h
let maxIncomingStreams = protocolHolder.protocol.maxIncomingStreams
if protocolHolder.openedStreams.getOrDefault(conn.peerId) >= maxIncomingStreams:
debug "Max streams for protocol reached, blocking new stream",
conn, protocol = ms, maxIncomingStreams
return
protocolHolder.openedStreams.inc(conn.peerId)
try:
await protocolHolder.protocol.handler(conn, ms)
finally:
protocolHolder.openedStreams.inc(conn.peerId, -1)
if protocolHolder.openedStreams[conn.peerId] == 0:
protocolHolder.openedStreams.del(conn.peerId)
return
debug "no handlers", conn, ms
except CancelledError as exc:
raise exc
except CatchableError as exc:
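handle and select are now callable as static procs that take the protocol list directly, so transports can negotiate without holding a MultistreamSelect instance. A sketch of both sides of the handshake (echoHandler is a hypothetical stream handler; conn is an established Connection):

# Listener: negotiate one of the offered protocols on an incoming stream.
let negotiated = await MultistreamSelect.handle(conn, @["/echo/1.0.0"])
if negotiated == "/echo/1.0.0":
  await echoHandler(conn)

# Dialer: the matching static select returns true on agreement.
let agreed = await MultistreamSelect.select(conn, "/echo/1.0.0")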

View File

@@ -12,7 +12,7 @@ when (NimMajor, NimMinor) < (1, 4):
else:
{.push raises: [].}
import pkg/[chronos, nimcrypto/utils, chronicles, stew/byteutils]
import pkg/[chronos, chronicles, stew/byteutils]
import ../../stream/connection,
../../utility,
../../varint,

View File

@@ -13,7 +13,7 @@ else:
{.push raises: [].}
import std/[oids, strformat]
import pkg/[chronos, chronicles, metrics, nimcrypto/utils]
import pkg/[chronos, chronicles, metrics]
import ./coder,
../muxer,
../../stream/[bufferstream, connection, streamseq],
@@ -236,7 +236,7 @@ proc completeWrite(
else:
await fut
when defined(libp2p_network_protocol_metrics):
when defined(libp2p_network_protocols_metrics):
if s.protocol.len > 0:
libp2p_protocols_bytes.inc(msgLen.int64, labelValues=[s.protocol, "out"])

View File

@@ -203,7 +203,8 @@ method handle*(m: Mplex) {.async, gcsafe.} =
proc new*(M: type Mplex,
conn: Connection,
inTimeout, outTimeout: Duration = DefaultChanTimeout,
inTimeout: Duration = DefaultChanTimeout,
outTimeout: Duration = DefaultChanTimeout,
maxChannCount: int = MaxChannelCount): Mplex =
M(connection: conn,
inChannTimeout: inTimeout,
@@ -248,3 +249,7 @@ method close*(m: Mplex) {.async, gcsafe.} =
m.channels[true].clear()
trace "Closed mplex", m
method getStreams*(m: Mplex): seq[Connection] =
for c in m.channels[false].values: result.add(c)
for c in m.channels[true].values: result.add(c)

View File

@@ -13,8 +13,7 @@ else:
{.push raises: [].}
import chronos, chronicles
import ../protocols/protocol,
../stream/connection,
import ../stream/connection,
../errors
logScope:
@@ -32,24 +31,28 @@ type
Muxer* = ref object of RootObj
streamHandler*: StreamHandler
handler*: Future[void]
connection*: Connection
# user-provided proc that returns a constructed Muxer
MuxerConstructor* = proc(conn: Connection): Muxer {.gcsafe, closure, raises: [Defect].}
# this wraps a creator proc that knows how to make muxers
MuxerProvider* = ref object of LPProtocol
MuxerProvider* = object
newMuxer*: MuxerConstructor
streamHandler*: StreamHandler # triggered every time there is a new stream, called for any muxer instance
muxerHandler*: MuxerHandler # triggered every time there is a new muxed connection created
codec*: string
func shortLog*(m: Muxer): auto = shortLog(m.connection)
func shortLog*(m: Muxer): auto =
if isNil(m): "nil"
else: shortLog(m.connection)
chronicles.formatIt(Muxer): shortLog(it)
# muxer interface
method newStream*(m: Muxer, name: string = "", lazy: bool = false):
Future[Connection] {.base, async, gcsafe.} = discard
method close*(m: Muxer) {.base, async, gcsafe.} = discard
method close*(m: Muxer) {.base, async, gcsafe.} =
if not isNil(m.connection):
await m.connection.close()
method handle*(m: Muxer): Future[void] {.base, async, gcsafe.} = discard
proc new*(
@@ -57,36 +60,7 @@ proc new*(
creator: MuxerConstructor,
codec: string): T {.gcsafe.} =
let muxerProvider = T(newMuxer: creator)
muxerProvider.codec = codec
muxerProvider.init()
let muxerProvider = T(newMuxer: creator, codec: codec)
muxerProvider
method init(c: MuxerProvider) =
proc handler(conn: Connection, proto: string) {.async, gcsafe, closure.} =
trace "starting muxer handler", proto=proto, conn
try:
let
muxer = c.newMuxer(conn)
if not isNil(c.streamHandler):
muxer.streamHandler = c.streamHandler
var futs = newSeq[Future[void]]()
futs &= muxer.handle()
# finally await both the futures
if not isNil(c.muxerHandler):
await c.muxerHandler(muxer)
when defined(libp2p_agents_metrics):
conn.shortAgent = muxer.connection.shortAgent
checkFutures(await allFinished(futs))
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "exception in muxer handler", exc = exc.msg, conn, proto
finally:
await conn.close()
c.handler = handler
method getStreams*(m: Muxer): seq[Connection] {.base.} = doAssert false, "not implemented"
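MuxerProvider is now a plain value object pairing a constructor with its codec; the connection handling that used to live in init's handler moved into the upgrade flow. Constructing a provider reduces to the following sketch, assuming the mplex module's exported Mplex.new and MplexCodec:

proc mplexConstructor(conn: Connection): Muxer {.gcsafe, raises: [Defect].} =
  Mplex.new(conn)

let provider = MuxerProvider.new(mplexConstructor, MplexCodec)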

View File

@@ -84,7 +84,7 @@ proc `$`(header: YamuxHeader): string =
proc encode(header: YamuxHeader): array[12, byte] =
result[0] = header.version
result[1] = uint8(header.msgType)
result[2..3] = toBytesBE(cast[uint16](header.flags))
result[2..3] = toBytesBE(uint16(cast[uint8](header.flags))) # workaround https://github.com/nim-lang/Nim/issues/21789
result[4..7] = toBytesBE(header.streamId)
result[8..11] = toBytesBE(header.length)
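The flags workaround routes the set through uint8 before widening, because casting a one-byte set[enum] straight to uint16 hits the linked Nim issue. The underlying bitmask layout, illustrated on a standalone enum with ordinals 0..3:

import stew/endians2

type Flag = enum Syn, Ack, Fin, Rst          # 4 members: the set fits in one byte
let flags = {Syn, Ack}
doAssert cast[uint8](flags) == 0b0000_0011   # bit i is set for ordinal i
doAssert toBytesBE(uint16(cast[uint8](flags))) == [0x00'u8, 0x03'u8]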
@@ -356,6 +356,8 @@ proc open*(channel: YamuxChannel) {.async, gcsafe.} =
channel.opened = true
await channel.conn.write(YamuxHeader.data(channel.id, 0, {if channel.isSrc: Syn else: Ack}))
method getWrapped*(channel: YamuxChannel): Connection = channel.conn
type
Yamux* = ref object of Muxer
channels: Table[uint32, YamuxChannel]
@@ -506,6 +508,9 @@ method handle*(m: Yamux) {.async, gcsafe.} =
await m.close()
trace "Stopped yamux handler"
method getStreams*(m: Yamux): seq[Connection] =
for c in m.channels.values: result.add(c)
method newStream*(
m: Yamux,
name: string = "",

View File

@@ -15,7 +15,8 @@ else:
import
std/[streams, strutils, sets, sequtils],
chronos, chronicles, stew/byteutils,
dnsclientpkg/[protocol, types]
dnsclientpkg/[protocol, types],
../utility
import
nameresolver
@@ -80,9 +81,7 @@ proc getDnsResponse(
# parseResponse can have a raises: [Exception, ..] because of
# https://github.com/nim-lang/Nim/commit/035134de429b5d99c5607c5fae912762bebb6008
# it can't actually raise though
return parseResponse(string.fromBytes(rawResponse))
except CatchableError as exc: raise exc
except Exception as exc: raiseAssert exc.msg
return exceptionToAssert: parseResponse(string.fromBytes(rawResponse))
finally:
await sock.closeWait()
@@ -118,9 +117,7 @@ method resolveIp*(
# https://github.com/nim-lang/Nim/commit/035134de429b5d99c5607c5fae912762bebb6008
# it can't actually raise though
resolvedAddresses.incl(
try: answer.toString()
except CatchableError as exc: raise exc
except Exception as exc: raiseAssert exc.msg
exceptionToAssert(answer.toString())
)
except CancelledError as e:
raise e
@@ -151,9 +148,13 @@ method resolveTxt*(
for _ in 0 ..< self.nameServers.len:
let server = self.nameServers[0]
try:
# toString can have a raises: [Exception, ..] because of
# https://github.com/nim-lang/Nim/commit/035134de429b5d99c5607c5fae912762bebb6008
# it can't actually raise though
let response = await getDnsResponse(server, address, TXT)
trace "Got TXT response", server = $server, answer=response.answers.mapIt(it.toString())
return response.answers.mapIt(it.toString())
return exceptionToAssert:
trace "Got TXT response", server = $server, answer=response.answers.mapIt(it.toString())
response.answers.mapIt(it.toString())
except CancelledError as e:
raise e
except CatchableError as e:
@@ -161,11 +162,6 @@ method resolveTxt*(
self.nameServers.add(self.nameServers[0])
self.nameServers.delete(0)
continue
except Exception as e:
# toString can have a raises: [Exception, ..] because of
# https://github.com/nim-lang/Nim/commit/035134de429b5d99c5607c5fae912762bebb6008
# it can't actually raise though
raiseAssert e.msg
debug "Failed to resolve TXT, returning empty set"
return @[]
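The repeated try/except Exception boilerplate collapses into an exceptionToAssert helper imported from ../utility. Its definition isn't shown in this diff; a minimal sketch of the intended semantics (an assumption, not the library's verbatim code):

template exceptionToAssert(body: untyped): untyped =
  try:
    body
  except CatchableError as exc:
    raise exc            # genuine runtime errors keep propagating
  except Exception as exc:
    raiseAssert exc.msg  # a bare Exception can only be a compiler-tracking artifact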

View File

@@ -13,7 +13,7 @@ else:
{.push raises: [].}
import
std/[streams, strutils, tables],
std/tables,
chronos, chronicles
import nameresolver

View File

@@ -16,7 +16,7 @@ import std/[sugar, sets, sequtils, strutils]
import
chronos,
chronicles,
stew/[endians2, byteutils]
stew/endians2
import ".."/[multiaddress, multicodec]
logScope:

View File

@@ -0,0 +1,98 @@
# Nim-LibP2P
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import
std/[sequtils, tables],
chronos, chronicles,
multiaddress, multicodec
type
## Manages MultiAddresses observed by remote peers. It keeps track of the most observed IP and IP/Port.
ObservedAddrManager* = ref object of RootObj
observedIPsAndPorts: seq[MultiAddress]
maxSize: int
minCount: int
proc addObservation*(self: ObservedAddrManager, observedAddr: MultiAddress): bool =
## Adds a new observed MultiAddress. If adding it would exceed maxSize, the oldest observation is removed first.
if self.observedIPsAndPorts.len >= self.maxSize:
self.observedIPsAndPorts.del(0)
self.observedIPsAndPorts.add(observedAddr)
return true
proc getProtocol(self: ObservedAddrManager, observations: seq[MultiAddress], multiCodec: MultiCodec): Opt[MultiAddress] =
var countTable = toCountTable(observations)
countTable.sort()
var orderedPairs = toSeq(countTable.pairs)
for (ma, count) in orderedPairs:
let maFirst = ma[0].get()
if maFirst.protoCode.get() == multiCodec and count >= self.minCount:
return Opt.some(ma)
return Opt.none(MultiAddress)
proc getMostObservedProtocol(self: ObservedAddrManager, multiCodec: MultiCodec): Opt[MultiAddress] =
## Returns the most observed IP address or none if the number of observations is less than minCount.
let observedIPs = self.observedIPsAndPorts.mapIt(it[0].get())
return self.getProtocol(observedIPs, multiCodec)
proc getMostObservedProtoAndPort(self: ObservedAddrManager, multiCodec: MultiCodec): Opt[MultiAddress] =
## Returns the most observed IP/Port address or none if the number of observations is less than minCount.
return self.getProtocol(self.observedIPsAndPorts, multiCodec)
proc getMostObservedProtosAndPorts*(self: ObservedAddrManager): seq[MultiAddress] =
## Returns the most observed IP4/Port and IP6/Port address or an empty seq if the number of observations
## is less than minCount.
var res: seq[MultiAddress]
let ip4 = self.getMostObservedProtoAndPort(multiCodec("ip4"))
if ip4.isSome():
res.add(ip4.get())
let ip6 = self.getMostObservedProtoAndPort(multiCodec("ip6"))
if ip6.isSome():
res.add(ip6.get())
return res
proc guessDialableAddr*(
self: ObservedAddrManager,
ma: MultiAddress): MultiAddress =
## Replaces the first proto value of each listen address with the corresponding (matching the proto code) most observed value.
## If the most observed value is not available, the original MultiAddress is returned.
try:
let maFirst = ma[0]
let maRest = ma[1..^1]
if maRest.isErr():
return ma
let observedIP = self.getMostObservedProtocol(maFirst.get().protoCode().get())
return
if observedIP.isNone() or maFirst.get() == observedIP.get():
ma
else:
observedIP.get() & maRest.get()
except CatchableError as error:
debug "Error while handling manual port forwarding", msg = error.msg
return ma
proc `$`*(self: ObservedAddrManager): string =
## Returns a string representation of the ObservedAddrManager.
return "IPs and Ports: " & $self.observedIPsAndPorts
proc new*(
T: typedesc[ObservedAddrManager],
maxSize = 10,
minCount = 3): T =
## Creates a new ObservedAddrManager.
return T(
observedIPsAndPorts: newSeq[MultiAddress](),
maxSize: maxSize,
minCount: minCount)
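A short usage sketch with the defaults above (an address is trusted once observed minCount = 3 times):

let oam = ObservedAddrManager.new()
let seen = MultiAddress.init("/ip4/8.8.8.8/tcp/4001").tryGet()
for _ in 0 ..< 3:
  discard oam.addObservation(seen)
echo oam.getMostObservedProtosAndPorts()   # @[/ip4/8.8.8.8/tcp/4001] after three sightings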

View File

@@ -28,11 +28,16 @@ else:
import
std/[tables, sets, options, macros],
chronos,
./crypto/crypto,
./protocols/identify,
./protocols/protocol,
./peerid, ./peerinfo,
./routing_record,
./multiaddress,
./stream/connection,
./multistream,
./muxers/muxer,
utility
type
@@ -70,11 +75,15 @@ type
PeerStore* {.public.} = ref object
books: Table[string, BasePeerBook]
identify: Identify
capacity*: int
toClean*: seq[PeerId]
proc new*(T: type PeerStore, capacity = 1000): PeerStore {.public.} =
T(capacity: capacity)
proc new*(T: type PeerStore, identify: Identify, capacity = 1000): PeerStore {.public.} =
T(
identify: identify,
capacity: capacity
)
#########################
# Generic Peer Book API #
@@ -186,3 +195,34 @@ proc cleanup*(
while peerStore.toClean.len > peerStore.capacity:
peerStore.del(peerStore.toClean[0])
peerStore.toClean.delete(0)
proc identify*(
peerStore: PeerStore,
muxer: Muxer) {.async.} =
# new stream for identify
var stream = await muxer.newStream()
if stream == nil:
return
try:
if (await MultistreamSelect.select(stream, peerStore.identify.codec())):
let info = await peerStore.identify.identify(stream, stream.peerId)
when defined(libp2p_agents_metrics):
var knownAgent = "unknown"
if info.agentVersion.isSome and info.agentVersion.get().len > 0:
let shortAgent = info.agentVersion.get().split("/")[0].safeToLowerAscii()
if shortAgent.isOk() and KnownLibP2PAgentsSeq.contains(shortAgent.get()):
knownAgent = shortAgent.get()
muxer.connection.setShortAgent(knownAgent)
peerStore.updatePeerInfo(info)
finally:
await stream.closeWithEOF()
proc getMostObservedProtosAndPorts*(self: PeerStore): seq[MultiAddress] =
return self.identify.observedAddrManager.getMostObservedProtosAndPorts()
proc guessDialableAddr*(self: PeerStore, ma: MultiAddress): MultiAddress =
return self.identify.observedAddrManager.guessDialableAddr(ma)

View File

@@ -72,12 +72,12 @@ type
hint | hint32 | hint64 | float32 | float64
const
SupportedWireTypes* = {
int(ProtoFieldKind.Varint),
int(ProtoFieldKind.Fixed64),
int(ProtoFieldKind.Length),
int(ProtoFieldKind.Fixed32)
}
SupportedWireTypes* = @[
uint64(ProtoFieldKind.Varint),
uint64(ProtoFieldKind.Fixed64),
uint64(ProtoFieldKind.Length),
uint64(ProtoFieldKind.Fixed32)
]
template checkFieldNumber*(i: int) =
doAssert((i > 0 and i < (1 shl 29)) and not(i >= 19000 and i <= 19999),
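SupportedWireTypes switches from a set[int] literal to a seq of uint64 values so it can be compared directly against varint-decoded keys. For context, a worked example of the protobuf field key these wire types plug into (assuming the enum ordinals mirror the wire-format values):

# A field header is encoded as (fieldNumber shl 3) or wireType.
# Field 2 with wire type Length (2): (2 shl 3) or 2 = 0x12.
let key = (2'u64 shl 3) or uint64(ProtoFieldKind.Length)
doAssert key == 0x12'u64
doAssert (key and 0x7'u64) in SupportedWireTypes   # wire type is supported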

View File

@@ -12,16 +12,14 @@ when (NimMajor, NimMinor) < (1, 4):
else:
{.push raises: [].}
import std/[options, sets, sequtils]
import stew/[results, objects]
import std/options
import stew/results
import chronos, chronicles
import ../../../switch,
../../../multiaddress,
../../../peerid
import core
export core
logScope:
topics = "libp2p autonat"
@@ -38,7 +36,7 @@ proc sendDial(conn: Connection, pid: PeerId, addrs: seq[MultiAddress]) {.async.}
method dialMe*(self: AutonatClient, switch: Switch, pid: PeerId, addrs: seq[MultiAddress] = newSeq[MultiAddress]()):
Future[MultiAddress] {.base, async.} =
proc getResponseOrRaise(autonatMsg: Option[AutonatMsg]): AutonatDialResponse {.raises: [UnpackError, AutonatError].} =
proc getResponseOrRaise(autonatMsg: Option[AutonatMsg]): AutonatDialResponse {.raises: [Defect, AutonatError].} =
if autonatMsg.isNone() or
autonatMsg.get().msgType != DialResponse or
autonatMsg.get().response.isNone() or
@@ -58,10 +56,14 @@ method dialMe*(self: AutonatClient, switch: Switch, pid: PeerId, addrs: seq[Mult
raise newException(AutonatError, "Unexpected error when dialling: " & err.msg, err)
# To bypass maxConnectionsPerPeer
let incomingConnection = switch.connManager.expectConnection(pid)
let incomingConnection = switch.connManager.expectConnection(pid, In)
if incomingConnection.failed() and incomingConnection.error of AlreadyExpectingConnectionError:
raise newException(AutonatError, incomingConnection.error.msg)
defer:
await conn.close()
incomingConnection.cancel() # Safer to always try to cancel because we aren't sure if the peer dialled us or not
if incomingConnection.completed():
await (await incomingConnection).connection.close()
trace "sending Dial", addrs = switch.peerInfo.addrs
await conn.sendDial(switch.peerInfo.peerId, switch.peerInfo.addrs)
let response = getResponseOrRaise(AutonatMsg.decode(await conn.readLp(1024)))

View File

@@ -12,7 +12,7 @@ when (NimMajor, NimMinor) < (1, 4):
else:
{.push raises: [].}
import std/[options, sets, sequtils]
import std/[options]
import stew/[results, objects]
import chronos, chronicles
import ../../../multiaddress,
@@ -58,6 +58,9 @@ type
dial*: Option[AutonatDial]
response*: Option[AutonatDialResponse]
NetworkReachability* {.pure.} = enum
Unknown, NotReachable, Reachable
proc encode(p: AutonatPeerInfo): ProtoBuffer =
result = initProtoBuffer()
if p.id.isSome():

View File

@@ -14,7 +14,7 @@ else:
import std/[options, sets, sequtils]
import stew/results
import chronos, chronicles, stew/objects
import chronos, chronicles
import ../../protocol,
../../../switch,
../../../multiaddress,
@@ -63,7 +63,10 @@ proc tryDial(autonat: Autonat, conn: Connection, addrs: seq[MultiAddress]) {.asy
var futs: seq[Future[Opt[MultiAddress]]]
try:
# This is to bypass the per peer max connections limit
let outgoingConnection = autonat.switch.connManager.expectConnection(conn.peerId)
let outgoingConnection = autonat.switch.connManager.expectConnection(conn.peerId, Out)
if outgoingConnection.failed() and outgoingConnection.error of AlreadyExpectingConnectionError:
await conn.sendResponseError(DialRefused, outgoingConnection.error.msg)
return
# Safer to always try to cancel because we aren't sure if the connection was established
defer: outgoingConnection.cancel()
# tryDial is to bypass the global max connections limit

View File

@@ -15,10 +15,14 @@ else:
import std/[options, deques, sequtils]
import chronos, metrics
import ../../../switch
import ../../../wire
import client
from core import NetworkReachability, AutonatUnreachableError
import ../../../utils/heartbeat
import ../../../crypto/crypto
export options, core.NetworkReachability
logScope:
topics = "libp2p autonatservice"
@@ -27,8 +31,9 @@ declarePublicGauge(libp2p_autonat_reachability_confidence, "autonat reachability
type
AutonatService* = ref object of Service
newConnectedPeerHandler: PeerEventHandler
addressMapper: AddressMapper
scheduleHandle: Future[void]
networkReachability: NetworkReachability
networkReachability*: NetworkReachability
confidence: Option[float]
answers: Deque[NetworkReachability]
autonatClient: AutonatClient
@@ -40,9 +45,7 @@ type
maxQueueSize: int
minConfidence: float
dialTimeout: Duration
NetworkReachability* {.pure.} = enum
NotReachable, Reachable, Unknown
enableAddressMapper: bool
StatusAndConfidenceHandler* = proc (networkReachability: NetworkReachability, confidence: Option[float]): Future[void] {.gcsafe, raises: [Defect].}
@@ -55,7 +58,8 @@ proc new*(
numPeersToAsk: int = 5,
maxQueueSize: int = 10,
minConfidence: float = 0.3,
dialTimeout = 30.seconds): T =
dialTimeout = 30.seconds,
enableAddressMapper = true): T =
return T(
scheduleInterval: scheduleInterval,
networkReachability: Unknown,
@@ -67,10 +71,8 @@ proc new*(
numPeersToAsk: numPeersToAsk,
maxQueueSize: maxQueueSize,
minConfidence: minConfidence,
dialTimeout: dialTimeout)
proc networkReachability*(self: AutonatService): NetworkReachability {.inline.} =
return self.networkReachability
dialTimeout: dialTimeout,
enableAddressMapper: enableAddressMapper)
proc callHandler(self: AutonatService) {.async.} =
if not isNil(self.statusAndConfidenceHandler):
@@ -80,6 +82,9 @@ proc hasEnoughIncomingSlots(switch: Switch): bool =
# we leave some margin instead of comparing to 0 as a peer could connect to us while we are asking for the dial back
return switch.connManager.slotsAvailable(In) >= 2
proc doesPeerHaveIncomingConn(switch: Switch, peerId: PeerId): bool =
return switch.connManager.selectMuxer(peerId, In) != nil
proc handleAnswer(self: AutonatService, ans: NetworkReachability) {.async.} =
if ans == Unknown:
@@ -104,6 +109,10 @@ proc handleAnswer(self: AutonatService, ans: NetworkReachability) {.async.} =
proc askPeer(self: AutonatService, switch: Switch, peerId: PeerId): Future[NetworkReachability] {.async.} =
logScope:
peerId = $peerId
if doesPeerHaveIncomingConn(switch, peerId):
return Unknown
if not hasEnoughIncomingSlots(switch):
debug "No incoming slots available, not asking peer", incomingSlotsAvailable=switch.connManager.slotsAvailable(In)
return Unknown
@@ -126,6 +135,7 @@ proc askPeer(self: AutonatService, switch: Switch, peerId: PeerId): Future[Netwo
await self.handleAnswer(ans)
if not isNil(self.statusAndConfidenceHandler):
await self.statusAndConfidenceHandler(self.networkReachability, self.confidence)
await switch.peerInfo.update()
return ans
proc askConnectedPeers(self: AutonatService, switch: Switch) {.async.} =
@@ -146,19 +156,41 @@ proc schedule(service: AutonatService, switch: Switch, interval: Duration) {.asy
heartbeat "Scheduling AutonatService run", interval:
await service.run(switch)
proc addressMapper(
self: AutonatService,
peerStore: PeerStore,
listenAddrs: seq[MultiAddress]): Future[seq[MultiAddress]] {.gcsafe, async.} =
if self.networkReachability != NetworkReachability.Reachable:
return listenAddrs
var addrs = newSeq[MultiAddress]()
for listenAddr in listenAddrs:
var processedMA = listenAddr
try:
let hostIP = initTAddress(listenAddr).get()
if not hostIP.isGlobal() and self.networkReachability == NetworkReachability.Reachable:
processedMA = peerStore.guessDialableAddr(listenAddr) # handle manual port forwarding
except CatchableError as exc:
debug "Error while handling address mapper", msg = exc.msg
addrs.add(processedMA)
return addrs
method setup*(self: AutonatService, switch: Switch): Future[bool] {.async.} =
self.addressMapper = proc (listenAddrs: seq[MultiAddress]): Future[seq[MultiAddress]] {.gcsafe, async.} =
return await addressMapper(self, switch.peerStore, listenAddrs)
info "Setting up AutonatService"
let hasBeenSetup = await procCall Service(self).setup(switch)
if hasBeenSetup:
if self.askNewConnectedPeers:
self.newConnectedPeerHandler = proc (peerId: PeerId, event: PeerEvent): Future[void] {.async.} =
if switch.connManager.selectConn(peerId, In) != nil: # no need to ask an incoming peer
return
discard askPeer(self, switch, peerId)
await self.callHandler()
switch.connManager.addPeerEventHandler(self.newConnectedPeerHandler, PeerEventKind.Joined)
if self.scheduleInterval.isSome():
self.scheduleHandle = schedule(self, switch, self.scheduleInterval.get())
if self.enableAddressMapper:
switch.peerInfo.addressMappers.add(self.addressMapper)
return hasBeenSetup
method run*(self: AutonatService, switch: Switch) {.async, public.} =
@@ -166,7 +198,6 @@ method run*(self: AutonatService, switch: Switch) {.async, public.} =
await askConnectedPeers(self, switch)
await self.callHandler()
method stop*(self: AutonatService, switch: Switch): Future[bool] {.async, public.} =
info "Stopping AutonatService"
let hasBeenStopped = await procCall Service(self).stop(switch)
@@ -176,6 +207,9 @@ method stop*(self: AutonatService, switch: Switch): Future[bool] {.async, public
self.scheduleHandle = nil
if not isNil(self.newConnectedPeerHandler):
switch.connManager.removePeerEventHandler(self.newConnectedPeerHandler, PeerEventKind.Joined)
if self.enableAddressMapper:
switch.peerInfo.addressMappers.keepItIf(it != self.addressMapper)
await switch.peerInfo.update()
return hasBeenStopped
proc statusAndConfidenceHandler*(self: AutonatService, statusAndConfidenceHandler: StatusAndConfidenceHandler) =
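The `networkReachability` field and the `StatusAndConfidenceHandler` callback are the service's public surface. A hedged sketch of registering a handler, assuming `autonatService` was constructed elsewhere:

import std/options
import chronos

proc onReachability(networkReachability: NetworkReachability,
                    confidence: Option[float]): Future[void] {.gcsafe, async.} =
  echo "reachability: ", networkReachability, " confidence: ", confidence

autonatService.statusAndConfidenceHandler(onReachability)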

View File

@@ -0,0 +1,92 @@
# Nim-LibP2P
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import std/sequtils
import stew/results
import chronos, chronicles
import core
import ../../protocol,
../../../stream/connection,
../../../switch,
../../../utils/future
export DcutrError
type
DcutrClient* = ref object
connectTimeout: Duration
maxDialableAddrs: int
logScope:
topics = "libp2p dcutrclient"
proc new*(T: typedesc[DcutrClient], connectTimeout = 15.seconds, maxDialableAddrs = 8): T =
return T(connectTimeout: connectTimeout, maxDialableAddrs: maxDialableAddrs)
proc startSync*(self: DcutrClient, switch: Switch, remotePeerId: PeerId, addrs: seq[MultiAddress]) {.async.} =
logScope:
peerId = switch.peerInfo.peerId
var
peerDialableAddrs: seq[MultiAddress]
stream: Connection
try:
var ourDialableAddrs = getHolePunchableAddrs(addrs)
if ourDialableAddrs.len == 0:
debug "Dcutr initiator has no supported dialable addresses. Aborting Dcutr.", addrs
return
stream = await switch.dial(remotePeerId, DcutrCodec)
await stream.send(MsgType.Connect, addrs)
debug "Dcutr initiator has sent a Connect message."
let rttStart = Moment.now()
let connectAnswer = DcutrMsg.decode(await stream.readLp(1024))
peerDialableAddrs = getHolePunchableAddrs(connectAnswer.addrs)
if peerDialableAddrs.len == 0:
debug "Dcutr receiver has no supported dialable addresses to connect to. Aborting Dcutr.", addrs=connectAnswer.addrs
return
let rttEnd = Moment.now()
debug "Dcutr initiator has received a Connect message back.", connectAnswer
let halfRtt = (rttEnd - rttStart) div 2'i64
await stream.send(MsgType.Sync, @[])
debug "Dcutr initiator has sent a Sync message."
await sleepAsync(halfRtt)
if peerDialableAddrs.len > self.maxDialableAddrs:
peerDialableAddrs = peerDialableAddrs[0..<self.maxDialableAddrs]
var futs = peerDialableAddrs.mapIt(switch.connect(stream.peerId, @[it], forceDial = true, reuseConnection = false))
try:
discard await anyCompleted(futs).wait(self.connectTimeout)
debug "Dcutr initiator has directly connected to the remote peer."
finally:
for fut in futs: fut.cancel()
except CancelledError as err:
raise err
except AllFuturesFailedError as err:
debug "Dcutr initiator could not connect to the remote peer, all connect attempts failed", peerDialableAddrs, msg = err.msg
raise newException(DcutrError, "Dcutr initiator could not connect to the remote peer, all connect attempts failed", err)
except AsyncTimeoutError as err:
debug "Dcutr initiator could not connect to the remote peer, all connect attempts timed out", peerDialableAddrs, msg = err.msg
raise newException(DcutrError, "Dcutr initiator could not connect to the remote peer, all connect attempts timed out", err)
except CatchableError as err:
debug "Unexpected error when trying direct conn", err = err.msg
raise newException(DcutrError, "Unexpected error when trying a direct conn", err)
finally:
if stream != nil:
await stream.close()
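The client side is driven by a single call. A sketch mirroring how the hole punching service invokes it further below, with `switch` and `remotePeerId` assumed:

let client = DcutrClient.new()  # defaults: 15 s connect timeout, at most 8 addresses
var natAddrs = switch.peerStore.getMostObservedProtosAndPorts()
if natAddrs.len == 0:
  natAddrs = switch.peerInfo.listenAddrs.mapIt(switch.peerStore.guessDialableAddr(it))
await client.startSync(switch, remotePeerId, natAddrs)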

View File

@@ -0,0 +1,63 @@
# Nim-LibP2P
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import std/sequtils
import chronos
import stew/objects
import ../../../multiaddress,
../../../errors,
../../../stream/connection
export multiaddress
const
DcutrCodec* = "/libp2p/dcutr"
type
MsgType* = enum
Connect = 100
Sync = 300
DcutrMsg* = object
msgType*: MsgType
addrs*: seq[MultiAddress]
DcutrError* = object of LPError
proc encode*(msg: DcutrMsg): ProtoBuffer =
result = initProtoBuffer()
result.write(1, msg.msgType.uint)
for addr in msg.addrs:
result.write(2, addr)
result.finish()
proc decode*(_: typedesc[DcutrMsg], buf: seq[byte]): DcutrMsg {.raises: [Defect, DcutrError].} =
var
msgTypeOrd: uint32
dcutrMsg: DcutrMsg
var pb = initProtoBuffer(buf)
var r1 = pb.getField(1, msgTypeOrd)
let r2 = pb.getRepeatedField(2, dcutrMsg.addrs)
if r1.isErr or r2.isErr or not checkedEnumAssign(dcutrMsg.msgType, msgTypeOrd):
raise newException(DcutrError, "Received malformed message")
return dcutrMsg
proc send*(conn: Connection, msgType: MsgType, addrs: seq[MultiAddress]) {.async.} =
let pb = DcutrMsg(msgType: msgType, addrs: addrs).encode()
await conn.writeLp(pb.buffer)
proc getHolePunchableAddrs*(addrs: seq[MultiAddress]): seq[MultiAddress] =
addrs.filterIt(TCP.match(it))
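The wire format is small enough to exercise in isolation. A round-trip sketch; the import path is an assumption based on this file's location:

import stew/results
import libp2p/multiaddress
import libp2p/protocols/connectivity/dcutr/core

let ma = MultiAddress.init("/ip4/192.0.2.1/tcp/4001").tryGet()
let msg = DcutrMsg(msgType: MsgType.Connect, addrs: @[ma])
let decoded = DcutrMsg.decode(msg.encode().buffer)  # raises DcutrError if malformed
assert decoded.msgType == MsgType.Connect
assert decoded.addrs == @[ma]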

View File

@@ -0,0 +1,84 @@
# Nim-LibP2P
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import std/[options, sets, sequtils]
import stew/[results, objects]
import chronos, chronicles
import core
import ../../protocol,
../../../stream/connection,
../../../switch,
../../../utils/future
export DcutrError
export chronicles
type Dcutr* = ref object of LPProtocol
logScope:
topics = "libp2p dcutr"
proc new*(T: typedesc[Dcutr], switch: Switch, connectTimeout = 15.seconds, maxDialableAddrs = 8): T =
proc handleStream(stream: Connection, proto: string) {.async, gcsafe.} =
var peerDialableAddrs: seq[MultiAddress]
try:
let connectMsg = DcutrMsg.decode(await stream.readLp(1024))
debug "Dcutr receiver received a Connect message.", connectMsg
var ourAddrs = switch.peerStore.getMostObservedProtosAndPorts() # likely empty when the peer is reachable
if ourAddrs.len == 0:
# this list should be the same as the peer's public addrs when it is reachable
ourAddrs = switch.peerInfo.listenAddrs.mapIt(switch.peerStore.guessDialableAddr(it))
var ourDialableAddrs = getHolePunchableAddrs(ourAddrs)
if ourDialableAddrs.len == 0:
debug "Dcutr receiver has no supported dialable addresses. Aborting Dcutr.", ourAddrs
return
await stream.send(MsgType.Connect, ourAddrs)
debug "Dcutr receiver has sent a Connect message back."
let syncMsg = DcutrMsg.decode(await stream.readLp(1024))
debug "Dcutr receiver has received a Sync message.", syncMsg
peerDialableAddrs = getHolePunchableAddrs(connectMsg.addrs)
if peerDialableAddrs.len == 0:
debug "Dcutr initiator has no supported dialable addresses to connect to. Aborting Dcutr.", addrs=connectMsg.addrs
return
if peerDialableAddrs.len > maxDialableAddrs:
peerDialableAddrs = peerDialableAddrs[0..<maxDialableAddrs]
var futs = peerDialableAddrs.mapIt(switch.connect(stream.peerId, @[it], forceDial = true, reuseConnection = false, upgradeDir = Direction.In))
try:
discard await anyCompleted(futs).wait(connectTimeout)
debug "Dcutr receiver has directly connected to the remote peer."
finally:
for fut in futs: fut.cancel()
except CancelledError as err:
raise err
except AllFuturesFailedError as err:
debug "Dcutr receiver could not connect to the remote peer, all connect attempts failed", peerDialableAddrs, msg = err.msg
raise newException(DcutrError, "Dcutr receiver could not connect to the remote peer, all connect attempts failed", err)
except AsyncTimeoutError as err:
debug "Dcutr receiver could not connect to the remote peer, all connect attempts timed out", peerDialableAddrs, msg = err.msg
raise newException(DcutrError, "Dcutr receiver could not connect to the remote peer, all connect attempts timed out", err)
except CatchableError as err:
warn "Unexpected error in dcutr handler", msg = err.msg
raise newException(DcutrError, "Unexpected error in dcutr handler", err)
let self = T()
self.handler = handleStream
self.codec = DcutrCodec
self

View File

@@ -12,7 +12,7 @@ when (NimMajor, NimMinor) < (1, 4):
else:
{.push raises: [].}
import options, macros, sequtils
import options, macros
import stew/objects
import ../../../peerinfo,
../../../signed_envelope
@@ -115,9 +115,10 @@ proc decode*(_: typedesc[RelayMessage], buf: seq[byte]): Option[RelayMessage] =
if r2.get(): rMsg.srcPeer = some(src)
if r3.get(): rMsg.dstPeer = some(dst)
if r4.get():
if statusOrd.int notin StatusV1:
var status: StatusV1
if not checkedEnumAssign(status, statusOrd):
return none(RelayMessage)
rMsg.status = some(StatusV1(statusOrd))
rMsg.status = some(status)
some(rMsg)
# Voucher
@@ -285,9 +286,10 @@ proc decode*(_: typedesc[HopMessage], buf: seq[byte]): Option[HopMessage] =
if r3.get(): msg.reservation = some(reservation)
if r4.get(): msg.limit = limit
if r5.get():
if statusOrd.int notin StatusV2:
var status: StatusV2
if not checkedEnumAssign(status, statusOrd):
return none(HopMessage)
msg.status = some(StatusV2(statusOrd))
msg.status = some(status)
some(msg)
# Circuit Relay V2 Stop Message
@@ -364,7 +366,8 @@ proc decode*(_: typedesc[StopMessage], buf: seq[byte]): Option[StopMessage] =
if r2.get(): msg.peer = some(peer)
if r3.get(): msg.limit = limit
if r4.get():
if statusOrd.int notin StatusV2:
var status: StatusV2
if not checkedEnumAssign(status, statusOrd):
return none(StopMessage)
msg.status = some(StatusV2(statusOrd))
msg.status = some(status)
some(msg)
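`checkedEnumAssign` from `stew/objects` replaces the hand-rolled membership checks: it assigns only when the integer maps to a declared enum value, which matters for holey enums like these status codes. A standalone sketch with a toy enum:

import stew/objects

type Status = enum
  ok = 100, refused = 200

var s: Status
assert checkedEnumAssign(s, 200'u32)      # declared value: assigned
assert not checkedEnumAssign(s, 999'u32)  # unknown value: rejected, s unchanged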

View File

@@ -12,7 +12,7 @@ when (NimMajor, NimMinor) < (1, 4):
else:
{.push raises: [].}
import options, sequtils, tables, sugar
import options, sequtils, tables
import chronos, chronicles
@@ -25,7 +25,6 @@ import ./messages,
../../../multicodec,
../../../stream/connection,
../../../protocols/protocol,
../../../transports/transport,
../../../errors,
../../../utils/heartbeat,
../../../signed_envelope
@@ -101,7 +100,7 @@ proc createReserveResponse(
status: some(Ok))
return ok(msg)
proc isRelayed(conn: Connection): bool =
proc isRelayed*(conn: Connection): bool =
var wrappedConn = conn
while not isNil(wrappedConn):
if wrappedConn of RelayConnection:

View File

@@ -26,7 +26,10 @@ import ../protobuf/minprotobuf,
../multiaddress,
../protocols/protocol,
../utility,
../errors
../errors,
../observedaddrmanager
export observedaddrmanager
logScope:
topics = "libp2p identify"
@@ -56,6 +59,7 @@ type
Identify* = ref object of LPProtocol
peerInfo*: PeerInfo
sendSignedPeerRecord*: bool
observedAddrManager*: ObservedAddrManager
IdentifyPushHandler* = proc (
peer: PeerId,
@@ -160,7 +164,8 @@ proc new*(
): T =
let identify = T(
peerInfo: peerInfo,
sendSignedPeerRecord: sendSignedPeerRecord
sendSignedPeerRecord: sendSignedPeerRecord,
observedAddrManager: ObservedAddrManager.new(),
)
identify.init()
identify
@@ -182,7 +187,7 @@ method init*(p: Identify) =
p.handler = handle
p.codec = IdentifyCodec
proc identify*(p: Identify,
proc identify*(self: Identify,
conn: Connection,
remotePeerId: PeerId): Future[IdentifyInfo] {.async, gcsafe.} =
trace "initiating identify", conn
@@ -194,23 +199,25 @@ proc identify*(p: Identify,
let infoOpt = decodeMsg(message)
if infoOpt.isNone():
raise newException(IdentityInvalidMsgError, "Incorrect message received!")
result = infoOpt.get()
if result.pubkey.isSome:
let peer = PeerId.init(result.pubkey.get())
if peer.isErr:
raise newException(IdentityInvalidMsgError, $peer.error)
else:
result.peerId = peer.get()
if peer.get() != remotePeerId:
trace "Peer ids don't match",
remote = peer,
local = remotePeerId
raise newException(IdentityNoMatchError, "Peer ids don't match")
else:
var info = infoOpt.get()
if info.pubkey.isNone():
raise newException(IdentityInvalidMsgError, "No pubkey in identify")
let peer = PeerId.init(info.pubkey.get())
if peer.isErr:
raise newException(IdentityInvalidMsgError, $peer.error)
if peer.get() != remotePeerId:
trace "Peer ids don't match", remote = peer, local = remotePeerId
raise newException(IdentityNoMatchError, "Peer ids don't match")
info.peerId = peer.get()
if info.observedAddr.isSome:
if not self.observedAddrManager.addObservation(info.observedAddr.get()):
debug "Observed address is not valid", observedAddr = info.observedAddr.get()
return info
proc new*(T: typedesc[IdentifyPush], handler: IdentifyPushHandler = nil): T {.public.} =
## Create a IdentifyPush protocol. `handler` will be called every time
## a peer sends us new `PeerInfo`

View File

@@ -15,7 +15,7 @@ else:
{.push raises: [].}
import chronos, chronicles
import bearssl/[rand, hash]
import bearssl/rand
import ../protobuf/minprotobuf,
../peerinfo,
../stream/connection,

View File

@@ -12,7 +12,7 @@ when (NimMajor, NimMinor) < (1, 4):
else:
{.push raises: [].}
import std/[sequtils, sets, hashes, tables]
import std/[sets, hashes, tables]
import chronos, chronicles, metrics
import ./pubsub,
./pubsubpeer,

View File

@@ -14,7 +14,7 @@ when (NimMajor, NimMinor) < (1, 4):
else:
{.push raises: [].}
import std/[tables, sets, options, sequtils]
import std/[sets, sequtils]
import chronos, chronicles, metrics
import ./pubsub,
./floodsub,
@@ -158,8 +158,7 @@ method onNewPeer(g: GossipSub, peer: PubSubPeer) =
peer.appScore = stats.appScore
peer.behaviourPenalty = stats.behaviourPenalty
peer.iWantBudget = IWantPeerBudget
peer.iHaveBudget = IHavePeerBudget
peer.iHaveBudget = IHavePeerBudget
method onPubSubPeerEvent*(p: GossipSub, peer: PubSubPeer, event: PubSubPeerEvent) {.gcsafe.} =
case event.kind
@@ -340,7 +339,7 @@ proc validateAndRelay(g: GossipSub,
# In theory, if topics are the same in all messages, we could batch - we'd
# also have to be careful to only include validated messages
g.broadcast(toSendPeers, RPCMsg(messages: @[msg]))
trace "forwared message to peers", peers = toSendPeers.len, msgId, peer
trace "forwarded message to peers", peers = toSendPeers.len, msgId, peer
for topic in msg.topicIds:
if topic notin g.topics: continue

View File

@@ -12,10 +12,10 @@ when (NimMajor, NimMinor) < (1, 4):
else:
{.push raises: [].}
import std/[tables, sequtils, sets, algorithm]
import std/[tables, sequtils, sets, algorithm, deques]
import chronos, chronicles, metrics
import "."/[types, scoring]
import ".."/[pubsubpeer, peertable, timedcache, mcache, floodsub, pubsub]
import ".."/[pubsubpeer, peertable, mcache, floodsub, pubsub]
import "../rpc"/[messages]
import "../../.."/[peerid, multiaddress, utility, switch, routing_record, signed_envelope, utils/heartbeat]
@@ -31,7 +31,7 @@ declareGauge(libp2p_gossipsub_no_peers_topics, "number of topics in mesh with no
declareGauge(libp2p_gossipsub_low_peers_topics, "number of topics in mesh with at least one but below dlow peers")
declareGauge(libp2p_gossipsub_healthy_peers_topics, "number of topics in mesh with at least dlow peers (but below dhigh)")
declareCounter(libp2p_gossipsub_above_dhigh_condition, "number of above dhigh pruning branches ran", labels = ["topic"])
declareSummary(libp2p_gossipsub_mcache_hit, "ratio of successful IWANT message cache lookups")
declareGauge(libp2p_gossipsub_received_iwants, "received iwants", labels = ["kind"])
proc grafted*(g: GossipSub, p: PubSubPeer, topic: string) {.raises: [Defect].} =
g.withPeerStats(p.peerId) do (stats: var PeerStats):
@@ -130,6 +130,10 @@ proc handleGraft*(g: GossipSub,
continue
if g.mesh.hasPeer(topic, peer):
trace "peer already in mesh", peer, topic
continue
# Check backingOff
# Ignore BackoffSlackTime here, since this only for outbound activity
# and subtract a second time to avoid race conditions
@@ -160,7 +164,8 @@ proc handleGraft*(g: GossipSub,
# If they send us a graft before they send us a subscribe, what should
# we do? For now, we add them to mesh but don't add them to gossipsub.
if topic in g.topics:
if g.mesh.peers(topic) < g.parameters.dHigh or peer.outbound:
if g.mesh.peers(topic) < g.parameters.dHigh or
(peer.outbound and g.mesh.outboundPeers(topic) < g.parameters.dOut):
# In the spec, there's no mention of DHi here, but implicitly, a
# peer will be removed from the mesh on next rebalance, so we don't want
# this peer to push someone else out
@@ -176,6 +181,10 @@ proc handleGraft*(g: GossipSub,
topicID: topic,
peers: g.peerExchangeList(topic),
backoff: g.parameters.pruneBackoff.seconds.uint64))
let backoff = Moment.fromNow(g.parameters.pruneBackoff)
g.backingOff
.mgetOrPut(topic, initTable[PeerId, Moment]())[peer.peerId] = backoff
else:
trace "peer grafting topic we're not interested in", peer, topic
# gossip 1.1, we do not send a control message prune anymore
@@ -272,28 +281,30 @@ proc handleIHave*(g: GossipSub,
proc handleIWant*(g: GossipSub,
peer: PubSubPeer,
iwants: seq[ControlIWant]): seq[Message] {.raises: [Defect].} =
var messages: seq[Message]
var
messages: seq[Message]
invalidRequests = 0
if peer.score < g.parameters.gossipThreshold:
trace "iwant: ignoring low score peer", peer, score = peer.score
elif peer.iWantBudget <= 0:
trace "iwant: ignoring out of budget peer", peer, score = peer.score
else:
let deIwants = iwants.deduplicate()
for iwant in deIwants:
let deIwantsMsgs = iwant.messageIds.deduplicate()
for mid in deIwantsMsgs:
for iwant in iwants:
for mid in iwant.messageIds:
trace "peer sent iwant", peer, messageID = mid
# canAskIWant will only return true once for a specific message
if not peer.canAskIWant(mid):
libp2p_gossipsub_received_iwants.inc(1, labelValues=["notsent"])
invalidRequests.inc()
if invalidRequests > 20:
libp2p_gossipsub_received_iwants.inc(1, labelValues=["skipped"])
return messages
continue
let msg = g.mcache.get(mid)
if msg.isSome:
libp2p_gossipsub_mcache_hit.observe(1)
# avoid spam
if peer.iWantBudget > 0:
messages.add(msg.get())
dec peer.iWantBudget
else:
break
libp2p_gossipsub_received_iwants.inc(1, labelValues=["correct"])
messages.add(msg.get())
else:
libp2p_gossipsub_mcache_hit.observe(0)
libp2p_gossipsub_received_iwants.inc(1, labelValues=["unknown"])
return messages
proc commitMetrics(metrics: var MeshMetrics) {.raises: [Defect].} =
@@ -318,10 +329,11 @@ proc rebalanceMesh*(g: GossipSub, topic: string, metrics: ptr MeshMetrics = nil)
var
prunes, grafts: seq[PubSubPeer]
npeers = g.mesh.peers(topic)
nOutPeers = g.mesh.outboundPeers(topic)
defaultMesh: HashSet[PubSubPeer]
backingOff = g.backingOff.getOrDefault(topic)
if npeers < g.parameters.dLow:
if npeers < g.parameters.dLow:
trace "replenishing mesh", peers = npeers
# replenish the mesh if we're below Dlo
@@ -360,7 +372,7 @@ proc rebalanceMesh*(g: GossipSub, topic: string, metrics: ptr MeshMetrics = nil)
g.fanout.removePeer(topic, peer)
grafts &= peer
else:
elif nOutPeers < g.parameters.dOut:
trace "replenishing mesh outbound quota", peers = g.mesh.peers(topic)
var
@@ -388,8 +400,8 @@ proc rebalanceMesh*(g: GossipSub, topic: string, metrics: ptr MeshMetrics = nil)
# sort peers by score, high score first, we are grafting
candidates.sort(byScore, SortOrder.Descending)
# Graft peers so we reach a count of D
candidates.setLen(min(candidates.len, g.parameters.dOut))
# Graft outgoing peers so we reach a count of dOut
candidates.setLen(min(candidates.len, g.parameters.dOut - nOutPeers))
trace "grafting outbound peers", topic, peers = candidates.len
@@ -616,8 +628,11 @@ proc getGossipPeers*(g: GossipSub): Table[PubSubPeer, ControlMessage] {.raises:
g.rng.shuffle(allPeers)
allPeers.setLen(target)
let msgIdsAsSet = ihave.messageIds.toHashSet()
for peer in allPeers:
control.mgetOrPut(peer, ControlMessage()).ihave.add(ihave)
peer.sentIHaves[^1].incl(msgIdsAsSet)
libp2p_gossipsub_cache_window_size.set(cacheWindowSize.int64)
@@ -628,7 +643,9 @@ proc onHeartbeat(g: GossipSub) {.raises: [Defect].} =
# reset IHAVE cap
block:
for peer in g.peers.values:
peer.iWantBudget = IWantPeerBudget
peer.sentIHaves.addFirst(default(HashSet[MessageId]))
if peer.sentIHaves.len > g.parameters.historyLength:
discard peer.sentIHaves.popLast()
peer.iHaveBudget = IHavePeerBudget
var meshMetrics = MeshMetrics()

View File

@@ -16,7 +16,7 @@ import std/[tables, sets, options]
import chronos, chronicles, metrics
import "."/[types]
import ".."/[pubsubpeer]
import "../../.."/[peerid, multiaddress, utility, switch, utils/heartbeat]
import "../../.."/[peerid, multiaddress, switch, utils/heartbeat]
logScope:
topics = "libp2p gossipsub"
@@ -250,7 +250,7 @@ proc updateScores*(g: GossipSub) = # avoid async
if g.parameters.disconnectBadPeers and stats.score < g.parameters.graylistThreshold and
peer.peerId notin g.parameters.directPeers:
debug "disconnecting bad score peer", peer, score = peer.score
asyncSpawn(try: g.disconnectPeer(peer) except Exception as exc: raiseAssert exc.msg)
asyncSpawn(g.disconnectPeer(peer))
libp2p_gossipsub_peers_scores.inc(peer.score, labelValues = [agent])
@@ -295,11 +295,11 @@ proc rewardDelivered*(
g.withPeerStats(peer.peerId) do (stats: var PeerStats):
stats.topicInfos.withValue(tt, tstats):
if tstats[].inMesh:
if first:
tstats[].firstMessageDeliveries.addCapped(
1, topicParams.firstMessageDeliveriesCap)
if first:
tstats[].firstMessageDeliveries.addCapped(
1, topicParams.firstMessageDeliveriesCap)
if tstats[].inMesh:
tstats[].meshMessageDeliveries.addCapped(
1, topicParams.meshMessageDeliveriesCap)
do: # make sure we don't lose this information

View File

@@ -13,11 +13,13 @@ else:
{.push raises: [].}
import chronos
import std/[tables, sets]
import std/[options, tables, sets]
import ".."/[floodsub, peertable, mcache, pubsubpeer]
import "../rpc"/[messages]
import "../../.."/[peerid, multiaddress, utility]
export options, tables, sets
const
GossipSubCodec* = "/meshsub/1.1.0"
GossipSubCodec_10* = "/meshsub/1.0.0"
@@ -46,7 +48,6 @@ const
const
BackoffSlackTime* = 2 # seconds
IWantPeerBudget* = 25 # 25 messages per second ( reset every heartbeat )
IHavePeerBudget* = 10
# the max amount of IHave to expose; not from the spec, but following go-libp2p as an example
# rust sigp: https://github.com/sigp/rust-libp2p/blob/f53d02bc873fef2bf52cd31e3d5ce366a41d8a8c/protocols/gossipsub/src/config.rs#L572

View File

@@ -12,9 +12,11 @@ when (NimMajor, NimMinor) < (1, 4):
else:
{.push raises: [].}
import std/[tables, sets]
import std/[tables, sets, sequtils]
import ./pubsubpeer, ../../peerid
export tables, sets
type
PeerTable* = Table[string, HashSet[PubSubPeer]] # topic string to peer map
@@ -51,3 +53,10 @@ func peers*(table: PeerTable, topic: string): int =
except KeyError: raiseAssert "checked with in"
else:
0
func outboundPeers*(table: PeerTable, topic: string): int =
if topic in table:
try: table[topic].countIt(it.outbound)
except KeyError: raiseAssert "checked with in"
else:
0

View File

@@ -36,6 +36,7 @@ import metrics
import stew/results
export results
export tables, sets
export PubSubPeer
export PubSubObserver
export protocol
@@ -118,7 +119,7 @@ type
anonymize*: bool ## if we omit fromPeer and seqno from RPC messages we send
subscriptionValidator*: SubscriptionValidator # callback used to validate subscriptions
topicsHigh*: int ## the maximum number of topics a peer is allowed to subscribe to
maxMessageSize*: int ##\
maxMessageSize*: int ##\
## the maximum raw message size we'll globally allow
## for finer tuning, check message size on topic validator
##
@@ -405,7 +406,11 @@ method onTopicSubscription*(p: PubSub, topic: string, subscribed: bool) {.base,
# Notify others that we are no longer interested in the topic
for _, peer in p.peers:
p.sendSubs(peer, [topic], subscribed)
# If we don't have a sendConn yet, we will
# send the full sub list when we get the sendConn,
# so no need to send it here
if peer.hasSendConn:
p.sendSubs(peer, [topic], subscribed)
if subscribed:
libp2p_pubsub_subscriptions.inc()

View File

@@ -12,7 +12,7 @@ when (NimMajor, NimMinor) < (1, 4):
else:
{.push raises: [].}
import std/[sequtils, strutils, tables, hashes, options]
import std/[sequtils, strutils, tables, hashes, options, sets, deques]
import stew/results
import chronos, chronicles, nimcrypto/sha2, metrics
import rpc/[messages, message, protobuf],
@@ -62,18 +62,24 @@ type
observers*: ref seq[PubSubObserver] # ref as in smart_ptr
score*: float64
iWantBudget*: int
sentIHaves*: Deque[HashSet[MessageId]]
iHaveBudget*: int
maxMessageSize: int
appScore*: float64 # application specific score
behaviourPenalty*: float64 # the eventual penalty score
when defined(libp2p_agents_metrics):
shortAgent*: string
RPCHandler* = proc(peer: PubSubPeer, msg: RPCMsg): Future[void]
{.gcsafe, raises: [Defect].}
when defined(libp2p_agents_metrics):
func shortAgent*(p: PubSubPeer): string =
if p.sendConn.isNil or p.sendConn.getWrapped().isNil:
"unknown"
else:
#TODO the sendConn is set up before identify,
# so we have to read the parent's short agent
p.sendConn.getWrapped().shortAgent
func hash*(p: PubSubPeer): Hash =
p.peerId.hash
@@ -177,6 +183,10 @@ proc connectOnce(p: PubSubPeer): Future[void] {.async.} =
# stop working so we make an effort to only keep a single channel alive
trace "Get new send connection", p, newConn
# Careful of race conditions here.
# Topic subscription relies on either connectedFut
# to be completed, or onEvent to be called later
p.connectedFut.complete()
p.sendConn = newConn
p.address = if p.sendConn.observedAddr.isSome: some(p.sendConn.observedAddr.get) else: none(MultiAddress)
@@ -217,6 +227,9 @@ proc connect*(p: PubSubPeer) =
asyncSpawn connectImpl(p)
proc hasSendConn*(p: PubSubPeer): bool =
p.sendConn != nil
template sendMetrics(msg: RPCMsg): untyped =
when defined(libp2p_expensive_metrics):
for x in msg.messages:
@@ -279,6 +292,13 @@ proc send*(p: PubSubPeer, msg: RPCMsg, anonymize: bool) {.raises: [Defect].} =
asyncSpawn p.sendEncoded(encoded)
proc canAskIWant*(p: PubSubPeer, msgId: MessageId): bool =
for sentIHave in p.sentIHaves.mitems():
if msgId in sentIHave:
sentIHave.excl(msgId)
return true
return false
proc new*(
T: typedesc[PubSubPeer],
peerId: PeerId,
@@ -287,7 +307,7 @@ proc new*(
codec: string,
maxMessageSize: int): T =
T(
result = T(
getConn: getConn,
onEvent: onEvent,
codec: codec,
@@ -295,3 +315,4 @@ proc new*(
connectedFut: newFuture[void](),
maxMessageSize: maxMessageSize
)
result.sentIHaves.addFirst(default(HashSet[MessageId]))
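The `sentIHaves` deque replaces the per-heartbeat `iWantBudget`: each IHAVE we send records the advertised message ids, and every id can be claimed via IWANT exactly once within `historyLength` heartbeats. A conceptual sketch, assuming a `peer: PubSubPeer` and a `msgId: MessageId`:

peer.sentIHaves[^1].incl(msgId)       # done in getGossipPeers when advertising

assert peer.canAskIWant(msgId)        # first IWANT for an advertised id is honoured
assert not peer.canAskIWant(msgId)    # the credit is consumed on first use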

View File

@@ -12,7 +12,6 @@ when (NimMajor, NimMinor) < (1, 4):
else:
{.push raises: [].}
import hashes
import chronicles, metrics, stew/[byteutils, endians2]
import ./messages,
./protobuf,

View File

@@ -14,7 +14,7 @@ else:
import std/[tables]
import chronos/timer
import chronos/timer, stew/results
const Timeout* = 10.seconds # default timeout in ms
@@ -22,6 +22,7 @@ type
TimedEntry*[K] = ref object of RootObj
key: K
addedAt: Moment
expiresAt: Moment
next, prev: TimedEntry[K]
TimedCache*[K] = object of RootObj
@@ -30,15 +31,14 @@ type
timeout: Duration
func expire*(t: var TimedCache, now: Moment = Moment.now()) =
let expirationLimit = now - t.timeout
while t.head != nil and t.head.addedAt < expirationLimit:
while t.head != nil and t.head.expiresAt < now:
t.entries.del(t.head.key)
t.head.prev = nil
t.head = t.head.next
if t.head == nil: t.tail = nil
func del*[K](t: var TimedCache[K], key: K): bool =
# Removes existing key from cache, returning false if it was not present
func del*[K](t: var TimedCache[K], key: K): Opt[TimedEntry[K]] =
# Removes existing key from cache, returning the previous value if present
var item: TimedEntry[K]
if t.entries.pop(key, item):
if t.head == item: t.head = item.next
@@ -46,9 +46,9 @@ func del*[K](t: var TimedCache[K], key: K): bool =
if item.next != nil: item.next.prev = item.prev
if item.prev != nil: item.prev.next = item.next
true
Opt.some(item)
else:
false
Opt.none(TimedEntry[K])
func put*[K](t: var TimedCache[K], k: K, now = Moment.now()): bool =
# Puts k in cache, returning true if the item was already present and false
@@ -56,9 +56,13 @@ func put*[K](t: var TimedCache[K], k: K, now = Moment.now()): bool =
# refreshed.
t.expire(now)
var res = t.del(k) # Refresh existing item
var previous = t.del(k) # Refresh existing item
let node = TimedEntry[K](key: k, addedAt: now)
let addedAt =
if previous.isSome: previous.get().addedAt
else: now
let node = TimedEntry[K](key: k, addedAt: addedAt, expiresAt: now + t.timeout)
if t.head == nil:
t.tail = node
@@ -66,7 +70,7 @@ func put*[K](t: var TimedCache[K], k: K, now = Moment.now()): bool =
else:
# search from tail because typically that's where we add when now grows
var cur = t.tail
while cur != nil and node.addedAt < cur.addedAt:
while cur != nil and node.expiresAt < cur.expiresAt:
cur = cur.prev
if cur == nil:
@@ -82,7 +86,7 @@ func put*[K](t: var TimedCache[K], k: K, now = Moment.now()): bool =
t.entries[k] = node
res
previous.isSome()
func contains*[K](t: TimedCache[K], k: K): bool =
k in t.entries
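With `expiresAt` stored per entry, a refreshing `put` extends the lifetime while `addedAt` keeps the original insertion time. A small sketch, assuming the module's `init(timeout)` constructor:

import chronos/timer

var cache = TimedCache[string].init(10.seconds)
let t0 = Moment.now()
assert not cache.put("msg", t0)           # new entry, expires at t0 + 10s
assert cache.put("msg", t0 + 5.seconds)   # refresh: expiry moves to t0 + 15s
cache.expire(t0 + 12.seconds)
assert "msg" in cache                     # survived thanks to the refresh
cache.expire(t0 + 16.seconds)
assert "msg" notin cache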

View File

@@ -12,7 +12,7 @@ when (NimMajor, NimMinor) < (1, 4):
else:
{.push raises: [].}
import std/[oids, strformat]
import std/strformat
import chronos
import chronicles
import bearssl/[rand, hash]

View File

@@ -56,7 +56,6 @@ proc new*(T: type SecureConn,
peerId: peerId,
observedAddr: observedAddr,
closeEvent: conn.closeEvent,
upgraded: conn.upgraded,
timeout: timeout,
dir: conn.dir)
result.initStream()

View File

@@ -51,10 +51,12 @@ proc decode*(
for address in addressInfos:
var addressInfo = AddressInfo()
let subProto = initProtoBuffer(address)
if ? subProto.getField(1, addressInfo.address) == false:
return err(ProtoError.RequiredFieldMissing)
let f = subProto.getField(1, addressInfo.address)
if f.isOk() and f.get():
record.addresses &= addressInfo
record.addresses &= addressInfo
if record.addresses.len == 0:
return err(ProtoError.RequiredFieldMissing)
ok(record)

View File

@@ -12,7 +12,7 @@ when (NimMajor, NimMinor) < (1, 4):
else:
{.push raises: [].}
import chronos, chronicles, times, tables, sequtils, options
import chronos, chronicles, times, tables, sequtils
import ../switch,
../protocols/connectivity/relay/[client, utils]
@@ -32,9 +32,15 @@ type
backingOff: seq[PeerId]
peerAvailable: AsyncEvent
onReservation: OnReservationHandler
addressMapper: AddressMapper
rng: ref HmacDrbgContext
proc reserveAndUpdate(self: AutoRelayService, relayPid: PeerId, selfPid: PeerId) {.async.} =
proc addressMapper(
self: AutoRelayService,
listenAddrs: seq[MultiAddress]): Future[seq[MultiAddress]] {.gcsafe, async.} =
return concat(toSeq(self.relayAddresses.values))
proc reserveAndUpdate(self: AutoRelayService, relayPid: PeerId, switch: Switch) {.async.} =
while self.running:
let
rsvp = await self.client.reserve(relayPid).wait(chronos.seconds(5))
@@ -46,11 +52,16 @@ proc reserveAndUpdate(self: AutoRelayService, relayPid: PeerId, selfPid: PeerId)
break
if relayPid notin self.relayAddresses or self.relayAddresses[relayPid] != relayedAddr:
self.relayAddresses[relayPid] = relayedAddr
await switch.peerInfo.update()
debug "Updated relay addresses", relayPid, relayedAddr
if not self.onReservation.isNil():
self.onReservation(concat(toSeq(self.relayAddresses.values)))
await sleepAsync chronos.seconds(ttl - 30)
method setup*(self: AutoRelayService, switch: Switch): Future[bool] {.async, gcsafe.} =
self.addressMapper = proc (listenAddrs: seq[MultiAddress]): Future[seq[MultiAddress]] {.gcsafe, async.} =
return await addressMapper(self, listenAddrs)
let hasBeenSetUp = await procCall Service(self).setup(switch)
if hasBeenSetUp:
proc handlePeerJoined(peerId: PeerId, event: PeerEvent) {.async.} =
@@ -63,6 +74,7 @@ method setup*(self: AutoRelayService, switch: Switch): Future[bool] {.async, gcs
future[].cancel()
switch.addPeerEventHandler(handlePeerJoined, Joined)
switch.addPeerEventHandler(handlePeerLeft, Left)
switch.peerInfo.addressMappers.add(self.addressMapper)
await self.run(switch)
return hasBeenSetUp
@@ -96,7 +108,7 @@ proc innerRun(self: AutoRelayService, switch: Switch) {.async, gcsafe.} =
for relayPid in connectedPeers:
if self.relayPeers.len() >= self.numRelays:
break
self.relayPeers[relayPid] = self.reserveAndUpdate(relayPid, switch.peerInfo.peerId)
self.relayPeers[relayPid] = self.reserveAndUpdate(relayPid, switch)
if self.relayPeers.len() > 0:
await one(toSeq(self.relayPeers.values())) or self.peerAvailable.wait()
@@ -116,6 +128,8 @@ method stop*(self: AutoRelayService, switch: Switch): Future[bool] {.async, gcsa
if hasBeenStopped:
self.running = false
self.runner.cancel()
switch.peerInfo.addressMappers.keepItIf(it != self.addressMapper)
await switch.peerInfo.update()
return hasBeenStopped
proc getAddresses*(self: AutoRelayService): seq[MultiAddress] =
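Like the AutonatService, this service plugs into the new `addressMappers` hook on `PeerInfo`: every registered mapper rewrites the advertised addresses when `update()` runs. A hedged sketch of a custom mapper, with `switch` assumed:

import chronos
import stew/results
import libp2p/multiaddress

proc staticMapper(listenAddrs: seq[MultiAddress]): Future[seq[MultiAddress]] {.gcsafe, async.} =
  # e.g. advertise a manually forwarded public address instead of the listen ones
  return @[MultiAddress.init("/ip4/203.0.113.7/tcp/4001").tryGet()]

switch.peerInfo.addressMappers.add(staticMapper)
await switch.peerInfo.update()  # re-runs all mappers and refreshes the advertised addresses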

View File

@@ -0,0 +1,129 @@
# Nim-LibP2P
# Copyright (c) 2022 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import std/[tables, sequtils]
import chronos, chronicles
import ../switch, ../wire
import ../protocols/rendezvous
import ../services/autorelayservice
import ../protocols/connectivity/relay/relay
import ../protocols/connectivity/autonat/service
import ../protocols/connectivity/dcutr/[client, server]
import ../multicodec
logScope:
topics = "libp2p hpservice"
type
HPService* = ref object of Service
newConnectedPeerHandler: PeerEventHandler
onNewStatusHandler: StatusAndConfidenceHandler
autoRelayService: AutoRelayService
autonatService: AutonatService
isPublicIPAddrProc: IsPublicIPAddrProc
IsPublicIPAddrProc* = proc(ta: TransportAddress): bool {.gcsafe, raises: [Defect].}
proc new*(T: typedesc[HPService], autonatService: AutonatService, autoRelayService: AutoRelayService,
isPublicIPAddrProc: IsPublicIPAddrProc = isGlobal): T =
return T(autonatService: autonatService, autoRelayService: autoRelayService, isPublicIPAddrProc: isPublicIPAddrProc)
proc tryStartingDirectConn(self: HPService, switch: Switch, peerId: PeerId): Future[bool] {.async.} =
proc tryConnect(address: MultiAddress): Future[bool] {.async.} =
debug "Trying to create direct connection", peerId, address
await switch.connect(peerId, @[address], true, false)
debug "Direct connection created."
return true
await sleepAsync(500.milliseconds) # wait for AddressBook to be populated
for address in switch.peerStore[AddressBook][peerId]:
try:
let isRelayed = address.contains(multiCodec("p2p-circuit"))
if isRelayed.isErr() or isRelayed.get():
continue
if DNS.matchPartial(address):
return await tryConnect(address)
else:
let ta = initTAddress(address)
if ta.isOk() and self.isPublicIPAddrProc(ta.get()):
return await tryConnect(address)
except CatchableError as err:
debug "Failed to create direct connection.", err = err.msg
continue
return false
proc closeRelayConn(relayedConn: Connection) {.async.} =
await sleepAsync(2000.milliseconds) # grace period before closing relayed connection
await relayedConn.close()
proc newConnectedPeerHandler(self: HPService, switch: Switch, peerId: PeerId, event: PeerEvent) {.async.} =
try:
# Get all connections to the peer. If there is at least one non-relayed connection, return.
let connections = switch.connManager.getConnections()[peerId].mapIt(it.connection)
if connections.anyIt(not isRelayed(it)):
return
let incomingRelays = connections.filterIt(it.transportDir == Direction.In)
if incomingRelays.len == 0:
return
let relayedConn = incomingRelays[0]
if await self.tryStartingDirectConn(switch, peerId):
await closeRelayConn(relayedConn)
return
let dcutrClient = DcutrClient.new()
var natAddrs = switch.peerStore.getMostObservedProtosAndPorts()
if natAddrs.len == 0:
natAddrs = switch.peerInfo.listenAddrs.mapIt(switch.peerStore.guessDialableAddr(it))
await dcutrClient.startSync(switch, peerId, natAddrs)
await closeRelayConn(relayedConn)
except CatchableError as err:
debug "Hole punching failed during dcutr", err = err.msg
method setup*(self: HPService, switch: Switch): Future[bool] {.async.} =
var hasBeenSetup = await procCall Service(self).setup(switch)
hasBeenSetup = hasBeenSetup and await self.autonatService.setup(switch)
if hasBeenSetup:
let dcutrProto = Dcutr.new(switch)
switch.mount(dcutrProto)
self.newConnectedPeerHandler = proc (peerId: PeerId, event: PeerEvent) {.async.} =
await newConnectedPeerHandler(self, switch, peerId, event)
switch.connManager.addPeerEventHandler(self.newConnectedPeerHandler, PeerEventKind.Joined)
self.onNewStatusHandler = proc (networkReachability: NetworkReachability, confidence: Option[float]) {.gcsafe, async.} =
if networkReachability == NetworkReachability.NotReachable:
discard await self.autoRelayService.setup(switch)
elif networkReachability == NetworkReachability.Reachable:
discard await self.autoRelayService.stop(switch)
# We do it here instead of in the AutonatService because this is useful only when hole punching.
for t in switch.transports:
t.networkReachability = networkReachability
self.autonatService.statusAndConfidenceHandler(self.onNewStatusHandler)
return hasBeenSetup
method run*(self: HPService, switch: Switch) {.async, public.} =
await self.autonatService.run(switch)
method stop*(self: HPService, switch: Switch): Future[bool] {.async, public.} =
discard await self.autonatService.stop(switch)
if not isNil(self.newConnectedPeerHandler):
switch.connManager.removePeerEventHandler(self.newConnectedPeerHandler, PeerEventKind.Joined)
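Wiring it together: the hole punching service composes the two services above and is handed to the switch at construction, which then drives its setup/run/stop lifecycle. A sketch, with the construction of `autonatService` and `autoRelayService` assumed:

let hpService = HPService.new(autonatService, autoRelayService)
# the switch owns the service lifecycle, e.g.:
# newSwitch(..., services = @[Service(hpService)])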

View File

@@ -18,9 +18,6 @@ import chronos, chronicles, metrics
import ../stream/connection
import ./streamseq
when chronicles.enabledLogLevel == LogLevel.TRACE:
import oids
export connection
logScope:

View File

@@ -12,7 +12,7 @@ when (NimMajor, NimMinor) < (1, 4):
else:
{.push raises: [].}
import std/[oids, strformat]
import std/[strformat]
import stew/results
import chronos, chronicles, metrics
import connection
@@ -103,11 +103,11 @@ method readOnce*(s: ChronosStream, pbytes: pointer, nbytes: int): Future[int] {.
withExceptions:
result = await s.client.readOnce(pbytes, nbytes)
s.activity = true # reset activity flag
libp2p_network_bytes.inc(nbytes.int64, labelValues = ["in"])
libp2p_network_bytes.inc(result.int64, labelValues = ["in"])
when defined(libp2p_agents_metrics):
s.trackPeerIdentity()
if s.tracked:
libp2p_peers_traffic_read.inc(nbytes.int64, labelValues = [s.shortAgent])
libp2p_peers_traffic_read.inc(result.int64, labelValues = [s.shortAgent])
proc completeWrite(
s: ChronosStream, fut: Future[int], msgLen: int): Future[void] {.async.} =

View File

@@ -39,7 +39,6 @@ type
timeoutHandler*: TimeoutHandler # timeout handler
peerId*: PeerId
observedAddr*: Opt[MultiAddress]
upgraded*: Future[void]
protocol*: string # protocol used by the connection, used as tag for metrics
transportDir*: Direction # The bottom level transport (generally the socket) direction
when defined(libp2p_agents_metrics):
@@ -47,22 +46,6 @@ type
proc timeoutMonitor(s: Connection) {.async, gcsafe.}
proc isUpgraded*(s: Connection): bool =
if not isNil(s.upgraded):
return s.upgraded.finished
proc upgrade*(s: Connection, failed: ref CatchableError = nil) =
if not isNil(s.upgraded):
if not isNil(failed):
s.upgraded.fail(failed)
return
s.upgraded.complete()
proc onUpgrade*(s: Connection) {.async.} =
if not isNil(s.upgraded):
await s.upgraded
func shortLog*(conn: Connection): string =
try:
if conn.isNil: "Connection(nil)"
@@ -80,9 +63,6 @@ method initStream*(s: Connection) =
doAssert(isNil(s.timerTaskFut))
if isNil(s.upgraded):
s.upgraded = newFuture[void]()
if s.timeout > 0.millis:
trace "Monitoring for timeout", s, timeout = s.timeout
@@ -100,10 +80,6 @@ method closeImpl*(s: Connection): Future[void] =
s.timerTaskFut.cancel()
s.timerTaskFut = nil
if not isNil(s.upgraded) and not s.upgraded.finished:
s.upgraded.cancel()
s.upgraded = nil
trace "Closed connection", s
procCall LPStream(s).closeImpl()
@@ -158,6 +134,13 @@ proc timeoutMonitor(s: Connection) {.async, gcsafe.} =
method getWrapped*(s: Connection): Connection {.base.} =
doAssert(false, "not implemented!")
when defined(libp2p_agents_metrics):
proc setShortAgent*(s: Connection, shortAgent: string) =
var conn = s
while not isNil(conn):
conn.shortAgent = shortAgent
conn = conn.getWrapped()
proc new*(C: type Connection,
peerId: PeerId,
dir: Direction,

View File

@@ -20,9 +20,7 @@ import std/[tables,
options,
sequtils,
sets,
oids,
sugar,
math]
oids]
import chronos,
chronicles,
@@ -30,14 +28,12 @@ import chronos,
import stream/connection,
transports/transport,
upgrademngrs/[upgrade, muxedupgrade],
upgrademngrs/upgrade,
multistream,
multiaddress,
protocols/protocol,
protocols/secure/secure,
peerinfo,
protocols/identify,
muxers/muxer,
utils/semaphore,
connmanager,
nameresolving/nameresolver,
@@ -58,8 +54,6 @@ logScope:
# and only if the channel has been secured (i.e. if a secure manager has been
# previously provided)
declareCounter(libp2p_failed_upgrades_incoming, "incoming connections failed upgrades")
const
ConcurrentUpgrades* = 4
@@ -149,10 +143,11 @@ method connect*(
peerId: PeerId,
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true): Future[void] {.public.} =
reuseConnection = true,
upgradeDir = Direction.Out): Future[void] {.public.} =
## Connects to a peer without opening a stream to it
s.dialer.connect(peerId, addrs, forceDial, reuseConnection)
s.dialer.connect(peerId, addrs, forceDial, reuseConnection, upgradeDir)
method connect*(
s: Switch,
@@ -220,24 +215,27 @@ proc mount*[T: LPProtocol](s: Switch, proto: T, matcher: Matcher = nil)
s.ms.addHandler(proto.codecs, proto, matcher)
s.peerInfo.protocols.add(proto.codec)
proc upgradeMonitor(conn: Connection, upgrades: AsyncSemaphore) {.async.} =
## monitor connection for upgrades
##
proc upgrader(switch: Switch, trans: Transport, conn: Connection) {.async.} =
let muxed = await trans.upgrade(conn, Direction.In, Opt.none(PeerId))
switch.connManager.storeMuxer(muxed)
await switch.peerStore.identify(muxed)
trace "Connection upgrade succeeded"
proc upgradeMonitor(
switch: Switch,
trans: Transport,
conn: Connection,
upgrades: AsyncSemaphore) {.async.} =
try:
# Since we don't control the flow of the
# upgrade, this timeout guarantees that a
# "hanged" remote doesn't hold the upgrade
# forever
await conn.onUpgrade.wait(30.seconds) # wait for connection to be upgraded
trace "Connection upgrade succeeded"
await switch.upgrader(trans, conn).wait(30.seconds)
except CatchableError as exc:
libp2p_failed_upgrades_incoming.inc()
if exc isnot CancelledError:
libp2p_failed_upgrades_incoming.inc()
if not isNil(conn):
await conn.close()
trace "Exception awaiting connection upgrade", exc = exc.msg, conn
finally:
upgrades.release() # don't forget to release the slot!
upgrades.release()
proc accept(s: Switch, transport: Transport) {.async.} = # noraises
## switch accept loop, ran for every transport
@@ -278,8 +276,7 @@ proc accept(s: Switch, transport: Transport) {.async.} = # noraises
conn.transportDir = Direction.In
debug "Accepted an incoming connection", conn
asyncSpawn upgradeMonitor(conn, upgrades)
asyncSpawn transport.upgradeIncoming(conn)
asyncSpawn s.upgradeMonitor(transport, conn, upgrades)
except CancelledError as exc:
trace "releasing semaphore on cancellation"
upgrades.release() # always release the slot
@@ -338,7 +335,7 @@ proc start*(s: Switch) {.async, gcsafe, public.} =
warn "Switch has already been started"
return
trace "starting switch for peer", peerInfo = s.peerInfo
debug "starting switch for peer", peerInfo = s.peerInfo
var startFuts: seq[Future[void]]
for t in s.transports:
let addrs = s.peerInfo.listenAddrs.filterIt(
@@ -377,14 +374,13 @@ proc start*(s: Switch) {.async, gcsafe, public.} =
proc newSwitch*(peerInfo: PeerInfo,
transports: seq[Transport],
identity: Identify,
secureManagers: openArray[Secure] = [],
connManager: ConnManager,
ms: MultistreamSelect,
peerStore: PeerStore,
nameResolver: NameResolver = nil,
peerStore = PeerStore.new(),
services = newSeq[Service]()): Switch
{.raises: [Defect, LPError], public.} =
{.raises: [Defect, LPError].} =
if secureManagers.len == 0:
raise newException(LPError, "Provide at least one secure manager")
@@ -394,11 +390,9 @@ proc newSwitch*(peerInfo: PeerInfo,
transports: transports,
connManager: connManager,
peerStore: peerStore,
dialer: Dialer.new(peerInfo.peerId, connManager, transports, ms, nameResolver),
dialer: Dialer.new(peerInfo.peerId, connManager, peerStore, transports, nameResolver),
nameResolver: nameResolver,
services: services)
switch.connManager.peerStore = peerStore
switch.mount(identity)
return switch

View File

@@ -14,14 +14,13 @@ when (NimMajor, NimMinor) < (1, 4):
else:
{.push raises: [].}
import std/[oids, sequtils]
import std/[sequtils]
import stew/results
import chronos, chronicles
import transport,
../errors,
../wire,
../multicodec,
../multistream,
../connmanager,
../multiaddress,
../stream/connection,
@@ -42,12 +41,15 @@ type
servers*: seq[StreamServer]
clients: array[Direction, seq[StreamTransport]]
flags: set[ServerFlags]
clientFlags: set[SocketFlags]
acceptFuts: seq[Future[StreamTransport]]
TcpTransportTracker* = ref object of TrackerBase
opened*: uint64
closed*: uint64
TcpTransportError* = object of transport.TransportError
proc setupTcpTransportTracker(): TcpTransportTracker {.gcsafe, raises: [Defect].}
proc getTcpTransportTracker(): TcpTransportTracker {.gcsafe.} =
@@ -129,9 +131,20 @@ proc new*(
flags: set[ServerFlags] = {},
upgrade: Upgrade): T {.public.} =
let transport = T(
flags: flags,
upgrader: upgrade)
let
transport = T(
flags: flags,
clientFlags:
if ServerFlags.TcpNoDelay in flags:
compilesOr:
{SocketFlags.TcpNoDelay}
do:
doAssert(false)
default(set[SocketFlags])
else:
default(set[SocketFlags]),
upgrader: upgrade,
networkReachability: NetworkReachability.Unknown)
return transport
@@ -154,6 +167,7 @@ method start*(
trace "Invalid address detected, skipping!", address = ma
continue
self.flags.incl(ServerFlags.ReusePort)
let server = createStreamServer(
ma = ma,
flags = self.flags,
@@ -252,8 +266,13 @@ method dial*(
##
trace "Dialing remote peer", address = $address
let transp =
if self.networkReachability == NetworkReachability.NotReachable and self.addrs.len > 0:
self.clientFlags.incl(SocketFlags.ReusePort)
await connect(address, flags = self.clientFlags, localAddress = Opt.some(self.addrs[0]))
else:
await connect(address, flags = self.clientFlags)
let transp = await connect(address)
try:
let observedAddr = await getObservedAddr(transp)
return await self.connHandler(transp, Opt.some(observedAddr), Direction.Out)
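Client-side TCP_NODELAY is now derived from the server flags at construction time rather than per dial. A construction sketch, assuming an `upgrade` manager built elsewhere:

let transport = TcpTransport.new(
  flags = {ServerFlags.TcpNoDelay},  # mirrored into clientFlags for outgoing dials
  upgrade = upgrade)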

View File

@@ -269,7 +269,7 @@ proc new*(
transports: switch.transports,
connManager: switch.connManager,
peerStore: switch.peerStore,
dialer: Dialer.new(switch.peerInfo.peerId, switch.connManager, switch.transports, switch.ms, nil),
dialer: Dialer.new(switch.peerInfo.peerId, switch.connManager, switch.peerStore, switch.transports, nil),
nameResolver: nil)
torSwitch.connManager.peerStore = switch.peerStore

View File

@@ -18,7 +18,11 @@ import chronos, chronicles
import ../stream/connection,
../multiaddress,
../multicodec,
../upgrademngrs/upgrade
../muxers/muxer,
../upgrademngrs/upgrade,
../protocols/connectivity/autonat/core
export core.NetworkReachability
logScope:
topics = "libp2p transport"
@@ -32,6 +36,7 @@ type
addrs*: seq[MultiAddress]
running*: bool
upgrader*: Upgrade
networkReachability*: NetworkReachability
proc newTransportClosedError*(parent: ref Exception = nil): ref LPError =
newException(TransportClosedError,
@@ -78,24 +83,16 @@ proc dial*(
peerId: Opt[PeerId] = Opt.none(PeerId)): Future[Connection] {.gcsafe.} =
self.dial("", address)
method upgradeIncoming*(
self: Transport,
conn: Connection): Future[void] {.base, gcsafe.} =
## base upgrade method that the transport uses to perform
## transport specific upgrades
##
self.upgrader.upgradeIncoming(conn)
method upgradeOutgoing*(
method upgrade*(
self: Transport,
conn: Connection,
peerId: Opt[PeerId]): Future[Connection] {.base, gcsafe.} =
direction: Direction,
peerId: Opt[PeerId]): Future[Muxer] {.base, gcsafe.} =
## base upgrade method that the transport uses to perform
## transport specific upgrades
##
self.upgrader.upgradeOutgoing(conn, peerId)
self.upgrader.upgrade(conn, direction, peerId)
method handles*(
self: Transport,

View File

@@ -12,7 +12,7 @@ when (NimMajor, NimMinor) < (1, 4):
else:
{.push raises: [].}
import std/[tables, sequtils]
import std/sequtils
import pkg/[chronos, chronicles, metrics]
import ../upgrademngrs/upgrade,
@@ -30,35 +30,24 @@ type
proc getMuxerByCodec(self: MuxedUpgrade, muxerName: string): MuxerProvider =
for m in self.muxers:
if muxerName in m.codecs:
if muxerName == m.codec:
return m
proc identify*(
self: MuxedUpgrade,
muxer: Muxer) {.async, gcsafe.} =
# new stream for identify
var stream = await muxer.newStream()
if stream == nil:
return
try:
await self.identify(stream)
when defined(libp2p_agents_metrics):
muxer.connection.shortAgent = stream.shortAgent
finally:
await stream.closeWithEOF()
proc mux*(
self: MuxedUpgrade,
conn: Connection): Future[Muxer] {.async, gcsafe.} =
## mux outgoing connection
conn: Connection,
direction: Direction): Future[Muxer] {.async, gcsafe.} =
## mux connection
trace "Muxing connection", conn
if self.muxers.len == 0:
warn "no muxers registered, skipping upgrade flow", conn
return
let muxerName = await self.ms.select(conn, self.muxers.mapIt(it.codec))
let muxerName =
if direction == Out: await self.ms.select(conn, self.muxers.mapIt(it.codec))
else: await MultistreamSelect.handle(conn, self.muxers.mapIt(it.codec))
if muxerName.len == 0 or muxerName == "na":
debug "no muxer available, early exit", conn
return
@@ -70,36 +59,23 @@ proc mux*(
# install stream handler
muxer.streamHandler = self.streamHandler
self.connManager.storeConn(conn)
# store it in muxed connections if we have a peer for it
self.connManager.storeMuxer(muxer, muxer.handle()) # store muxer and start read loop
try:
await self.identify(muxer)
except CatchableError as exc:
# Identify is non-essential, though if it fails, it might indicate that
# the connection was closed already - this will be picked up by the read
# loop
debug "Could not identify connection", conn, msg = exc.msg
muxer.handler = muxer.handle()
return muxer
method upgradeOutgoing*(
method upgrade*(
self: MuxedUpgrade,
conn: Connection,
peerId: Opt[PeerId]): Future[Connection] {.async, gcsafe.} =
trace "Upgrading outgoing connection", conn
direction: Direction,
peerId: Opt[PeerId]): Future[Muxer] {.async.} =
trace "Upgrading connection", conn, direction
let sconn = await self.secure(conn, peerId) # secure the connection
let sconn = await self.secure(conn, direction, peerId) # secure the connection
if isNil(sconn):
raise newException(UpgradeFailedError,
"unable to secure connection, stopping upgrade")
let muxer = await self.mux(sconn) # mux it if possible
let muxer = await self.mux(sconn, direction) # mux it if possible
if muxer == nil:
# TODO this might be relaxed in the future
raise newException(UpgradeFailedError,
"a muxer is required for outgoing connections")
@@ -111,108 +87,17 @@ method upgradeOutgoing*(
raise newException(UpgradeFailedError,
"Connection closed or missing peer info, stopping upgrade")
trace "Upgraded outgoing connection", conn, sconn
return sconn
method upgradeIncoming*(
self: MuxedUpgrade,
incomingConn: Connection) {.async, gcsafe.} = # noraises
trace "Upgrading incoming connection", incomingConn
let ms = MultistreamSelect.new()
# secure incoming connections
proc securedHandler(conn: Connection,
proto: string)
{.async, gcsafe, closure.} =
trace "Starting secure handler", conn
let secure = self.secureManagers.filterIt(it.codec == proto)[0]
var cconn = conn
try:
var sconn = await secure.secure(cconn, false, Opt.none(PeerId))
if isNil(sconn):
return
cconn = sconn
# add the muxer
for muxer in self.muxers:
ms.addHandler(muxer.codecs, muxer)
# handle subsequent secure requests
await ms.handle(cconn)
except CatchableError as exc:
debug "Exception in secure handler during incoming upgrade", msg = exc.msg, conn
if not cconn.isUpgraded:
cconn.upgrade(exc)
finally:
if not isNil(cconn):
await cconn.close()
trace "Stopped secure handler", conn
try:
if (await ms.select(incomingConn)): # just handshake
# add the secure handlers
for k in self.secureManagers:
ms.addHandler(k.codec, securedHandler)
# handle un-secured connections
# we handshaked above, set this ms handler as active
await ms.handle(incomingConn, active = true)
except CatchableError as exc:
debug "Exception upgrading incoming", exc = exc.msg
if not incomingConn.isUpgraded:
incomingConn.upgrade(exc)
finally:
if not isNil(incomingConn):
await incomingConn.close()
proc muxerHandler(
self: MuxedUpgrade,
muxer: Muxer) {.async, gcsafe.} =
let
conn = muxer.connection
# store incoming connection
self.connManager.storeConn(conn)
# store muxer and muxed connection
self.connManager.storeMuxer(muxer)
try:
await self.identify(muxer)
when defined(libp2p_agents_metrics):
#TODO Passing data between layers is a pain
if muxer.connection of SecureConn:
let secureConn = (SecureConn)muxer.connection
secureConn.stream.shortAgent = muxer.connection.shortAgent
except IdentifyError as exc:
# Identify is non-essential, though if it fails, it might indicate that
# the connection was closed already - this will be picked up by the read
# loop
debug "Could not identify connection", conn, msg = exc.msg
except LPStreamClosedError as exc:
debug "Identify stream closed", conn, msg = exc.msg
except LPStreamEOFError as exc:
debug "Identify stream EOF", conn, msg = exc.msg
except CancelledError as exc:
await muxer.close()
raise exc
except CatchableError as exc:
await muxer.close()
trace "Exception in muxer handler", conn, msg = exc.msg
trace "Upgraded connection", conn, sconn, direction
return muxer
proc new*(
T: type MuxedUpgrade,
identity: Identify,
muxers: seq[MuxerProvider],
secureManagers: openArray[Secure] = [],
connManager: ConnManager,
ms: MultistreamSelect): T =
let upgrader = T(
identity: identity,
muxers: muxers,
secureManagers: @secureManagers,
connManager: connManager,
@@ -231,10 +116,4 @@ proc new*(
await conn.closeWithEOF()
trace "Stream handler done", conn
for _, val in muxers:
val.streamHandler = upgrader.streamHandler
val.muxerHandler = proc(muxer: Muxer): Future[void]
{.raises: [Defect].} =
upgrader.muxerHandler(muxer)
return upgrader

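The direction now decides which side of multistream-select each peer plays during muxer negotiation: the dialer proposes with select, the listener answers with handle. Sketched with illustrative codec strings:

let codecs = @["/mplex/6.7.0", "/yamux/1.0.0"]
let muxerName =
  if direction == Out:
    await ms.select(conn, codecs)                # propose, the remote picks
  else:
    await MultistreamSelect.handle(conn, codecs) # answer the remote's proposal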

@@ -13,21 +13,22 @@ when (NimMajor, NimMinor) < (1, 4):
else:
{.push raises: [].}
import std/[options, sequtils, strutils]
import std/[sequtils, strutils]
import pkg/[chronos, chronicles, metrics]
import ../stream/connection,
../protocols/secure/secure,
../protocols/identify,
../muxers/muxer,
../multistream,
../peerstore,
../connmanager,
../errors,
../utility
export connmanager, connection, identify, secure, multistream
declarePublicCounter(libp2p_failed_upgrade, "peers failed upgrade")
declarePublicCounter(libp2p_failed_upgrades_incoming, "incoming connections failed upgrades")
declarePublicCounter(libp2p_failed_upgrades_outgoing, "outgoing connections failed upgrades")
logScope:
topics = "libp2p upgrade"
@@ -37,29 +38,27 @@ type
Upgrade* = ref object of RootObj
ms*: MultistreamSelect
identity*: Identify
connManager*: ConnManager
secureManagers*: seq[Secure]
method upgradeIncoming*(
self: Upgrade,
conn: Connection): Future[void] {.base.} =
doAssert(false, "Not implemented!")
method upgradeOutgoing*(
method upgrade*(
self: Upgrade,
conn: Connection,
peerId: Opt[PeerId]): Future[Connection] {.base.} =
direction: Direction,
peerId: Opt[PeerId]): Future[Muxer] {.base.} =
doAssert(false, "Not implemented!")
proc secure*(
self: Upgrade,
conn: Connection,
direction: Direction,
peerId: Opt[PeerId]): Future[Connection] {.async, gcsafe.} =
if self.secureManagers.len <= 0:
raise newException(UpgradeFailedError, "No secure managers registered!")
let codec = await self.ms.select(conn, self.secureManagers.mapIt(it.codec))
let codec =
if direction == Out: await self.ms.select(conn, self.secureManagers.mapIt(it.codec))
else: await MultistreamSelect.handle(conn, self.secureManagers.mapIt(it.codec))
if codec.len == 0:
raise newException(UpgradeFailedError, "Unable to negotiate a secure channel!")
@@ -70,30 +69,4 @@ proc secure*(
# let's avoid duplicating checks but detect if it fails to do it properly
doAssert(secureProtocol.len > 0)
return await secureProtocol[0].secure(conn, true, peerId)
proc identify*(
self: Upgrade,
conn: Connection) {.async, gcsafe.} =
## identify the connection
if (await self.ms.select(conn, self.identity.codec)):
let
info = await self.identity.identify(conn, conn.peerId)
peerStore = self.connManager.peerStore
if info.pubkey.isNone and isNil(conn):
raise newException(UpgradeFailedError,
"no public key provided and no existing peer identity found")
conn.peerId = info.peerId
when defined(libp2p_agents_metrics):
conn.shortAgent = "unknown"
if info.agentVersion.isSome and info.agentVersion.get().len > 0:
let shortAgent = info.agentVersion.get().split("/")[0].safeToLowerAscii()
if shortAgent.isOk() and KnownLibP2PAgentsSeq.contains(shortAgent.get()):
conn.shortAgent = shortAgent.get()
peerStore.updatePeerInfo(info)
trace "identified remote peer", conn, peerId = shortLog(conn.peerId)
return await secureProtocol[0].secure(conn, direction == Out, peerId)

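Note how the handshake role falls out of the direction: the last line passes direction == Out as the initiator flag, so dialers start the secure-channel handshake and listeners answer it. In other words (sketch):

# direction == Out -> initiator = true  (we send the first handshake message)
# direction == In  -> initiator = false (we wait for it)
let sconn = await secureProtocol[0].secure(conn, direction == Out, peerId)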

@@ -19,6 +19,12 @@ template public* {.pragma.}
const
ShortDumpMax = 12
template compilesOr*(a, b: untyped): untyped =
when compiles(a):
a
else:
b
func shortLog*(item: openArray[byte]): string =
if item.len <= ShortDumpMax:
result = item.toHex()
@@ -57,5 +63,26 @@ when defined(libp2p_agents_metrics):
err("toLowerAscii failed")
const
KnownLibP2PAgents* {.strdefine.} = ""
KnownLibP2PAgents* {.strdefine.} = "nim-libp2p"
KnownLibP2PAgentsSeq* = KnownLibP2PAgents.safeToLowerAscii().tryGet().split(",")
template safeConvert*[T: SomeInteger, S: Ordinal](value: S): T =
## Converts `value` from S to `T` iff `value` is guaranteed to be preserved.
when int64(T.low) <= int64(S.low()) and uint64(T.high) >= uint64(S.high):
T(value)
else:
{.error: "Source and target types have an incompatible range low..high".}
template exceptionToAssert*(body: untyped): untyped =
block:
var res: type(body)
when defined(nimHasWarnBareExcept):
{.push warning[BareExcept]:off.}
try:
res = body
except CatchableError as exc: raise exc
except Defect as exc: raise exc
except Exception as exc: raiseAssert exc.msg
when defined(nimHasWarnBareExcept):
{.pop.}
res

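Hedged examples of the three new utility templates (values and call names are illustrative; parseInt is from std/strutils):

# compilesOr: pick a newer API shape when the installed dependency has it.
compilesOr:
  someCall(arg, extraArg)      # hypothetical newer signature
do:
  someCall(arg)                # fallback for older versions

# safeConvert: conversion that must be lossless, checked at compile time.
let wide: int64 = safeConvert[int64, int32](42'i32)  # ok, int32 fits in int64
# safeConvert[int8, int16](300'i16) would fail to compile: range too small.

# exceptionToAssert: turn a bare Exception from foreign code into a Defect,
# while CatchableError and Defect propagate unchanged.
let n = exceptionToAssert:
  parseInt("42")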

@@ -12,7 +12,6 @@ when (NimMajor, NimMinor) < (1, 4):
else:
{.push raises: [].}
import sequtils
import chronos, chronicles
export chronicles


@@ -14,7 +14,7 @@ else:
## This module implements wire network connection procedures.
import chronos, stew/endians2
import multiaddress, multicodec, errors
import multiaddress, multicodec, errors, utility
when defined(windows):
import winlean
@@ -23,14 +23,14 @@ else:
const
RTRANSPMA* = mapOr(
TCP,
WebSockets,
TCP_IP,
WebSockets_IP,
UNIX
)
TRANSPMA* = mapOr(
RTRANSPMA,
UDP
UDP_IP
)
@@ -71,12 +71,14 @@ proc initTAddress*(ma: MultiAddress): MaResult[TransportAddress] =
res.port = Port(fromBytesBE(uint16, pbuf))
ok(res)
else:
err("MultiAddress must be wire address (tcp, udp or unix)")
err("MultiAddress must be wire address (tcp, udp or unix): " & $ma)
proc connect*(
ma: MultiAddress,
bufferSize = DefaultStreamBufferSize,
child: StreamTransport = nil): Future[StreamTransport]
child: StreamTransport = nil,
flags = default(set[SocketFlags]),
localAddress: Opt[MultiAddress] = Opt.none(MultiAddress)): Future[StreamTransport]
{.raises: [Defect, LPError, MaInvalidAddress].} =
## Open new connection to remote peer with address ``ma`` and create
## new transport object ``StreamTransport`` for established connection.
@@ -86,7 +88,15 @@ proc connect*(
if not(RTRANSPMA.match(ma)):
raise newException(MaInvalidAddress, "Incorrect or unsupported address!")
return connect(initTAddress(ma).tryGet(), bufferSize, child)
let transportAddress = initTAddress(ma).tryGet()
compilesOr:
return connect(transportAddress, bufferSize, child,
if localAddress.isSome(): initTAddress(localAddress.get()).tryGet() else: TransportAddress(),
flags)
do:
# support for older chronos versions
return connect(transportAddress, bufferSize, child)
proc createStreamServer*[T](ma: MultiAddress,
cbproc: StreamCallback,

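With the extra parameters, callers can pin the local endpoint of an outgoing stream, which the reuse-port dial above depends on; older chronos versions silently fall back to the two-argument form via compilesOr. A usage sketch with placeholder addresses:

let
  remote = MultiAddress.init("/ip4/192.0.2.7/tcp/4001").tryGet()
  local  = MultiAddress.init("/ip4/0.0.0.0/tcp/4001").tryGet()
# ReusePort is required because our listener already occupies the local port.
let transp = await connect(remote,
                           flags = {SocketFlags.ReusePort},
                           localAddress = Opt.some(local))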

@@ -1,6 +1,6 @@
import unittest2
import unittest2, chronos
export unittest2
export unittest2, chronos
template asyncTeardown*(body: untyped): untyped =
teardown:


@@ -1,13 +1,11 @@
{.used.}
import sequtils
import chronos, stew/[byteutils, results]
import ../libp2p/[stream/connection,
transports/transport,
upgrademngrs/upgrade,
multiaddress,
errors,
wire]
errors]
import ./helpers


@@ -1,5 +1,5 @@
import ../config.nims
import strutils
import strutils, os
--threads:on
--d:metrics
@@ -8,18 +8,21 @@ import strutils
--d:libp2p_protobuf_metrics
--d:libp2p_network_protocols_metrics
--d:libp2p_mplex_metrics
--d:unittestPrintTime
--skipParentCfg
# Only add chronicles param if the
# user didn't specify any
var hasChroniclesParam = false
for param in 0..<paramCount():
if "chronicles" in paramStr(param):
for param in 0..<system.paramCount():
if "chronicles" in system.paramStr(param):
hasChroniclesParam = true
if hasChroniclesParam:
echo "Since you specified chronicles params, TRACE won't be tested!"
else:
switch("import", "stublogger")
let modulePath = currentSourcePath.parentDir / "stublogger"
switch("import", modulePath)
switch("define", "chronicles_sinks=textlines[stdout],json[dynamic]")
switch("define", "chronicles_log_level=TRACE")
switch("define", "chronicles_runtime_filtering=TRUE")


@@ -4,6 +4,7 @@ else:
{.push raises: [].}
import chronos
import algorithm
import ../libp2p/transports/tcptransport
import ../libp2p/stream/bufferstream
@@ -55,9 +56,14 @@ template checkTrackers*() =
checkpoint tracker.dump()
fail()
# Also test the GC is not fooling with us
when defined(nimHasWarnBareExcept):
{.push warning[BareExcept]:off.}
try:
GC_fullCollect()
except: discard
except:
discard
when defined(nimHasWarnBareExcept):
{.pop.}
type RngWrap = object
rng: ref HmacDrbgContext
@@ -84,6 +90,8 @@ type
method write*(s: TestBufferStream, msg: seq[byte]): Future[void] =
s.writeHandler(msg)
method getWrapped*(s: TestBufferStream): Connection = nil
proc new*(T: typedesc[TestBufferStream], writeHandler: WriteHandler): T =
let testBufferStream = T(writeHandler: writeHandler)
testBufferStream.initStream()
@@ -117,3 +125,19 @@ proc checkExpiringInternal(cond: proc(): bool {.raises: [Defect], gcsafe.} ): Fu
template checkExpiring*(code: untyped): untyped =
check await checkExpiringInternal(proc(): bool = code)
proc unorderedCompare*[T](a, b: seq[T]): bool =
if a == b:
return true
if a.len != b.len:
return false
var aSorted = a
var bSorted = b
aSorted.sort()
bSorted.sort()
if aSorted == bSorted:
return true
return false
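unorderedCompare treats both sequences as multisets, e.g.:

check unorderedCompare(@[1, 2, 3], @[3, 1, 2])      # true: same elements
check not unorderedCompare(@[1, 2, 2], @[1, 2, 3])  # false: multisets differ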

9
tests/pubsub/config.nims Normal file

@@ -0,0 +1,9 @@
import strutils
proc hasSkipParentCfg: bool =
for param in 0..<paramCount():
if "skipParent" in paramStr(param):
return true
when hasSkipParentCfg():
import ../config.nims


@@ -19,7 +19,9 @@ import utils,
protocols/pubsub/pubsub,
protocols/pubsub/floodsub,
protocols/pubsub/rpc/messages,
protocols/pubsub/peertable]
protocols/pubsub/peertable,
protocols/pubsub/pubsubpeer
]
import ../../libp2p/protocols/pubsub/errors as pubsub_errors
import ../helpers
@@ -62,6 +64,14 @@ suite "FloodSub":
check (await nodes[0].publish("foobar", "Hello!".toBytes())) > 0
check (await completionFut.wait(5.seconds)) == true
when defined(libp2p_agents_metrics):
let
agentA = nodes[0].peers[nodes[1].switch.peerInfo.peerId].shortAgent
agentB = nodes[1].peers[nodes[0].switch.peerInfo.peerId].shortAgent
check:
agentA == "nim-libp2p"
agentB == "nim-libp2p"
await allFuturesThrowing(
nodes[0].switch.stop(),
nodes[1].switch.stop()


@@ -2,13 +2,14 @@ include ../../libp2p/protocols/pubsub/gossipsub
{.used.}
import options
import std/[options, deques]
import stew/byteutils
import ../../libp2p/builders
import ../../libp2p/errors
import ../../libp2p/crypto/crypto
import ../../libp2p/stream/bufferstream
import ../../libp2p/switch
import ../../libp2p/muxers/muxer
import ../helpers
@@ -495,7 +496,7 @@ suite "GossipSub internal":
peer.handler = handler
peer.appScore = gossipSub.parameters.graylistThreshold - 1
gossipSub.gossipsub.mgetOrPut(topic, initHashSet[PubSubPeer]()).incl(peer)
gossipSub.switch.connManager.storeConn(conn)
gossipSub.switch.connManager.storeMuxer(Muxer(connection: conn))
gossipSub.updateScores()
@@ -712,6 +713,7 @@ suite "GossipSub internal":
let peer = gossipSub.getPubSubPeer(peerId)
let id = @[0'u8, 1, 2, 3]
gossipSub.mcache.put(id, Message())
peer.sentIHaves[^1].incl(id)
let msg = ControlIWant(
messageIDs: @[id, id, id]
)


@@ -107,11 +107,7 @@ suite "GossipSub":
nodes[0].subscribe("foobar", handler)
nodes[1].subscribe("foobar", handler)
var subs: seq[Future[void]]
subs &= waitSub(nodes[1], nodes[0], "foobar")
subs &= waitSub(nodes[0], nodes[1], "foobar")
await allFuturesThrowing(subs)
await waitSubGraph(nodes, "foobar")
let gossip1 = GossipSub(nodes[0])
let gossip2 = GossipSub(nodes[1])
@@ -157,11 +153,7 @@ suite "GossipSub":
nodes[0].subscribe("foobar", handler)
nodes[1].subscribe("foobar", handler)
var subs: seq[Future[void]]
subs &= waitSub(nodes[1], nodes[0], "foobar")
subs &= waitSub(nodes[0], nodes[1], "foobar")
await allFuturesThrowing(subs)
await waitSubGraph(nodes, "foobar")
let gossip1 = GossipSub(nodes[0])
let gossip2 = GossipSub(nodes[1])
@@ -424,8 +416,6 @@ suite "GossipSub":
await passed.wait(2.seconds)
trace "test done, stopping..."
await allFuturesThrowing(
nodes[0].switch.stop(),
nodes[1].switch.stop()
@@ -452,21 +442,23 @@ suite "GossipSub":
nodes[1].switch.start(),
)
GossipSub(nodes[1]).parameters.d = 0
GossipSub(nodes[1]).parameters.dHigh = 0
GossipSub(nodes[1]).parameters.dLow = 0
await subscribeNodes(nodes)
nodes[1].subscribe("foobar", handler)
nodes[0].subscribe("foobar", handler)
await waitSub(nodes[0], nodes[1], "foobar")
await waitSub(nodes[1], nodes[0], "foobar")
nodes[0].unsubscribe("foobar", handler)
nodes[1].subscribe("foobar", handler)
let gsNode = GossipSub(nodes[1])
checkExpiring: gsNode.mesh.getOrDefault("foobar").len == 0
nodes[0].subscribe("foobar", handler)
check GossipSub(nodes[0]).mesh.getOrDefault("foobar").len == 0
checkExpiring:
gsNode.mesh.getOrDefault("foobar").len == 0 and
GossipSub(nodes[0]).mesh.getOrDefault("foobar").len == 0 and
(
GossipSub(nodes[0]).gossipsub.getOrDefault("foobar").len == 1 or
GossipSub(nodes[0]).fanout.getOrDefault("foobar").len == 1
)
tryPublish await nodes[0].publish("foobar", "Hello!".toBytes()), 1
@@ -532,8 +524,8 @@ suite "GossipSub":
asyncTest "e2e - GossipSub should not send to source & peers who already seen":
# 3 nodes: A, B, C
# A publishes, B relays, C is having a long validation
# so C should not send to anyone
# A publishes, C relays, B is having a long validation
# so B should not send to anyone
let
nodes = generateNodes(
@@ -566,10 +558,7 @@ suite "GossipSub":
nodes[0].subscribe("foobar", handlerA)
nodes[1].subscribe("foobar", handlerB)
nodes[2].subscribe("foobar", handlerC)
await waitSub(nodes[0], nodes[1], "foobar")
await waitSub(nodes[0], nodes[2], "foobar")
await waitSub(nodes[2], nodes[1], "foobar")
await waitSub(nodes[1], nodes[2], "foobar")
await waitSubGraph(nodes, "foobar")
var gossip1: GossipSub = GossipSub(nodes[0])
var gossip2: GossipSub = GossipSub(nodes[1])
@@ -587,7 +576,11 @@ suite "GossipSub":
nodes[1].addValidator("foobar", slowValidator)
tryPublish await nodes[0].publish("foobar", "Hello!".toBytes()), 1
checkExpiring(
gossip1.mesh.getOrDefault("foobar").len == 2 and
gossip2.mesh.getOrDefault("foobar").len == 2 and
gossip3.mesh.getOrDefault("foobar").len == 2)
tryPublish await nodes[0].publish("foobar", "Hello!".toBytes()), 2
await bFinished
@@ -629,7 +622,7 @@ suite "GossipSub":
tryPublish await nodes[0].publish("foobar", "Hello!".toBytes()), 1
check await passed
check await passed.wait(10.seconds)
check:
"foobar" in gossip1.gossipsub


@@ -132,13 +132,17 @@ proc waitSubGraph*(nodes: seq[PubSub], key: string) {.async, gcsafe.} =
seen: HashSet[PeerId]
for n in nodes:
nodesMesh[n.peerInfo.peerId] = toSeq(GossipSub(n).mesh.getOrDefault(key).items()).mapIt(it.peerId)
proc explore(p: PeerId) =
if p in seen: return
seen.incl(p)
for peer in nodesMesh.getOrDefault(p):
explore(peer)
explore(nodes[0].peerInfo.peerId)
if seen.len == nodes.len: return
var ok = 0
for n in nodes:
seen.clear()
proc explore(p: PeerId) =
if p in seen: return
seen.incl(p)
for peer in nodesMesh.getOrDefault(p):
explore(peer)
explore(n.peerInfo.peerId)
if seen.len == nodes.len: ok.inc()
if ok == nodes.len: return
trace "waitSubGraph sleeping..."
await sleepAsync(5.milliseconds)

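The rewrite fixes a subtle hole: gossipsub meshes are directed, so reachability from nodes[0] alone does not prove the graph is fully connected. The loop now re-runs the exploration from every node. A hypothetical mesh where the old check passed but the new one keeps waiting:

# nodesMesh: A -> [B], B -> [C], C -> []
# explore(A) visits {A, B, C}  -> old single-origin check succeeded
# explore(C) visits {C}        -> new per-node check keeps waiting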

@@ -1,12 +1,13 @@
import std/typetraits
import chronicles
when not defined(nimscript):
import std/typetraits
import chronicles
when defined(chronicles_runtime_filtering):
setLogLevel(INFO)
when defined(chronicles_runtime_filtering):
setLogLevel(INFO)
when defaultChroniclesStream.outputs.type.arity == 1:
# Hide the json logs, they're just here to check if we compile
proc noOutput(logLevel: LogLevel, msg: LogOutputStr) = discard
defaultChroniclesStream.outputs[0].writer = noOutput
when defaultChroniclesStream.outputs.type.arity == 1:
# Hide the json logs, they're just here to check if we compile
proc noOutput(logLevel: LogLevel, msg: LogOutputStr) = discard
defaultChroniclesStream.outputs[0].writer = noOutput
{.used.}


@@ -19,6 +19,7 @@ import ../../libp2p/[protocols/connectivity/autonat/client,
peerid,
multiaddress,
switch]
from ../../libp2p/protocols/connectivity/autonat/core import NetworkReachability, AutonatUnreachableError, AutonatError
type
AutonatClientStub* = ref object of AutonatClient


@@ -0,0 +1,48 @@
# Nim-LibP2P
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
{.used.}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import chronos
import ../../libp2p/[peerid, multiaddress, switch]
type
SwitchStub* = ref object of Switch
switch*: Switch
connectStub*: proc(): Future[void] {.async.}
method connect*(
self: SwitchStub,
peerId: PeerId,
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true,
upgradeDir = Direction.Out) {.async.} =
if (self.connectStub != nil):
await self.connectStub()
else:
await self.switch.connect(peerId, addrs, forceDial, reuseConnection, upgradeDir)
proc new*(T: typedesc[SwitchStub], switch: Switch, connectStub: proc (): Future[void] {.async.} = nil): T =
return SwitchStub(
switch: switch,
peerInfo: switch.peerInfo,
ms: switch.ms,
transports: switch.transports,
connManager: switch.connManager,
peerStore: switch.peerStore,
dialer: switch.dialer,
nameResolver: switch.nameResolver,
services: switch.services,
connectStub: connectStub)

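SwitchStub delegates to a real switch but lets a test swap the connect behaviour, e.g. to simulate a dial that hangs (a sketch, with createSwitch as in these tests):

proc slowConnect(): Future[void] {.async.} =
  await sleepAsync(10.seconds)   # longer than any test timeout

let stub = SwitchStub.new(createSwitch(), slowConnect)
# stub.connect(...) now runs slowConnect instead of dialing the network.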

@@ -1,3 +1,12 @@
# Nim-LibP2P
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
{.used.}
when (NimMajor, NimMinor) < (1, 4):


@@ -1,3 +1,14 @@
{.used.}
# Nim-Libp2p
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import std/options
import chronos
import


@@ -1,4 +1,6 @@
# Nim-LibP2P
{.used.}
# Nim-Libp2p
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
@@ -7,7 +9,7 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
import std/options
import std/[options, sequtils]
import chronos, metrics
import unittest2
import ../libp2p/[builders,
@@ -51,7 +53,7 @@ suite "Autonat Service":
let switch3 = createSwitch()
let switch4 = createSwitch()
check autonatService.networkReachability() == NetworkReachability.Unknown
check autonatService.networkReachability == NetworkReachability.Unknown
await switch1.start()
await switch2.start()
@@ -64,7 +66,7 @@ suite "Autonat Service":
await autonatClientStub.finished
check autonatService.networkReachability() == NetworkReachability.NotReachable
check autonatService.networkReachability == NetworkReachability.NotReachable
check libp2p_autonat_reachability_confidence.value(["NotReachable"]) == 0.3
await allFuturesThrowing(
@@ -86,7 +88,7 @@ suite "Autonat Service":
if not awaiter.finished:
awaiter.complete()
check autonatService.networkReachability() == NetworkReachability.Unknown
check autonatService.networkReachability == NetworkReachability.Unknown
autonatService.statusAndConfidenceHandler(statusAndConfidenceHandler)
@@ -101,12 +103,16 @@ suite "Autonat Service":
await awaiter
check autonatService.networkReachability() == NetworkReachability.Reachable
check autonatService.networkReachability == NetworkReachability.Reachable
check libp2p_autonat_reachability_confidence.value(["Reachable"]) == 0.3
check switch1.peerInfo.addrs == switch1.peerInfo.listenAddrs.mapIt(switch1.peerStore.guessDialableAddr(it))
await allFuturesThrowing(
switch1.stop(), switch2.stop(), switch3.stop(), switch4.stop())
check switch1.peerInfo.addrs == switch1.peerInfo.listenAddrs
asyncTest "Peer must be not reachable and then reachable":
let autonatClientStub = AutonatClientStub.new(expectedDials = 6)
@@ -127,7 +133,7 @@ suite "Autonat Service":
autonatClientStub.answer = Reachable
awaiter.complete()
check autonatService.networkReachability() == NetworkReachability.Unknown
check autonatService.networkReachability == NetworkReachability.Unknown
autonatService.statusAndConfidenceHandler(statusAndConfidenceHandler)
@@ -142,12 +148,12 @@ suite "Autonat Service":
await awaiter
check autonatService.networkReachability() == NetworkReachability.NotReachable
check autonatService.networkReachability == NetworkReachability.NotReachable
check libp2p_autonat_reachability_confidence.value(["NotReachable"]) == 0.3
await autonatClientStub.finished
check autonatService.networkReachability() == NetworkReachability.Reachable
check autonatService.networkReachability == NetworkReachability.Reachable
check libp2p_autonat_reachability_confidence.value(["Reachable"]) == 0.3
await allFuturesThrowing(switch1.stop(), switch2.stop(), switch3.stop(), switch4.stop())
@@ -168,7 +174,7 @@ suite "Autonat Service":
if not awaiter.finished:
awaiter.complete()
check autonatService.networkReachability() == NetworkReachability.Unknown
check autonatService.networkReachability == NetworkReachability.Unknown
autonatService.statusAndConfidenceHandler(statusAndConfidenceHandler)
@@ -183,7 +189,7 @@ suite "Autonat Service":
await awaiter
check autonatService.networkReachability() == NetworkReachability.Reachable
check autonatService.networkReachability == NetworkReachability.Reachable
check libp2p_autonat_reachability_confidence.value(["Reachable"]) == 1
await allFuturesThrowing(
@@ -209,7 +215,7 @@ suite "Autonat Service":
autonatClientStub.answer = Unknown
awaiter.complete()
check autonatService.networkReachability() == NetworkReachability.Unknown
check autonatService.networkReachability == NetworkReachability.Unknown
autonatService.statusAndConfidenceHandler(statusAndConfidenceHandler)
@@ -224,12 +230,12 @@ suite "Autonat Service":
await awaiter
check autonatService.networkReachability() == NetworkReachability.NotReachable
check autonatService.networkReachability == NetworkReachability.NotReachable
check libp2p_autonat_reachability_confidence.value(["NotReachable"]) == 1/3
await autonatClientStub.finished
check autonatService.networkReachability() == NetworkReachability.NotReachable
check autonatService.networkReachability == NetworkReachability.NotReachable
check libp2p_autonat_reachability_confidence.value(["NotReachable"]) == 1/3
await allFuturesThrowing(switch1.stop(), switch2.stop(), switch3.stop(), switch4.stop())
@@ -260,7 +266,7 @@ suite "Autonat Service":
if not awaiter.finished:
awaiter.complete()
check autonatService.networkReachability() == NetworkReachability.Unknown
check autonatService.networkReachability == NetworkReachability.Unknown
autonatService.statusAndConfidenceHandler(statusAndConfidenceHandler)
@@ -271,12 +277,100 @@ suite "Autonat Service":
await awaiter
check autonatService.networkReachability() == NetworkReachability.Reachable
check autonatService.networkReachability == NetworkReachability.Reachable
check libp2p_autonat_reachability_confidence.value(["Reachable"]) == 1
await allFuturesThrowing(
switch1.stop(), switch2.stop())
asyncTest "Must work when peers ask each other at the same time with max 1 conn per peer":
let autonatService1 = AutonatService.new(AutonatClient.new(), newRng(), some(500.millis), maxQueueSize = 3)
let autonatService2 = AutonatService.new(AutonatClient.new(), newRng(), some(500.millis), maxQueueSize = 3)
let autonatService3 = AutonatService.new(AutonatClient.new(), newRng(), some(500.millis), maxQueueSize = 3)
let switch1 = createSwitch(autonatService1, maxConnsPerPeer = 0)
let switch2 = createSwitch(autonatService2, maxConnsPerPeer = 0)
let switch3 = createSwitch(autonatService2, maxConnsPerPeer = 0)
let awaiter1 = newFuture[void]()
let awaiter2 = newFuture[void]()
let awaiter3 = newFuture[void]()
proc statusAndConfidenceHandler1(networkReachability: NetworkReachability, confidence: Option[float]) {.gcsafe, async.} =
if networkReachability == NetworkReachability.Reachable and confidence.isSome() and confidence.get() == 1:
if not awaiter1.finished:
awaiter1.complete()
proc statusAndConfidenceHandler2(networkReachability: NetworkReachability, confidence: Option[float]) {.gcsafe, async.} =
if networkReachability == NetworkReachability.Reachable and confidence.isSome() and confidence.get() == 1:
if not awaiter2.finished:
awaiter2.complete()
check autonatService1.networkReachability == NetworkReachability.Unknown
check autonatService2.networkReachability == NetworkReachability.Unknown
autonatService1.statusAndConfidenceHandler(statusAndConfidenceHandler1)
autonatService2.statusAndConfidenceHandler(statusAndConfidenceHandler2)
await switch1.start()
await switch2.start()
await switch3.start()
await switch1.connect(switch2.peerInfo.peerId, switch2.peerInfo.addrs)
await switch2.connect(switch1.peerInfo.peerId, switch1.peerInfo.addrs)
await switch2.connect(switch3.peerInfo.peerId, switch3.peerInfo.addrs)
await awaiter1
await awaiter2
check autonatService1.networkReachability == NetworkReachability.Reachable
check autonatService2.networkReachability == NetworkReachability.Reachable
check libp2p_autonat_reachability_confidence.value(["Reachable"]) == 1
await allFuturesThrowing(
switch1.stop(), switch2.stop(), switch3.stop())
asyncTest "Must work for one peer when two peers ask each other at the same time with max 1 conn per peer":
let autonatService1 = AutonatService.new(AutonatClient.new(), newRng(), some(500.millis), maxQueueSize = 3)
let autonatService2 = AutonatService.new(AutonatClient.new(), newRng(), some(500.millis), maxQueueSize = 3)
let switch1 = createSwitch(autonatService1, maxConnsPerPeer = 0)
let switch2 = createSwitch(autonatService2, maxConnsPerPeer = 0)
let awaiter1 = newFuture[void]()
proc statusAndConfidenceHandler1(networkReachability: NetworkReachability, confidence: Option[float]) {.gcsafe, async.} =
if networkReachability == NetworkReachability.Reachable and confidence.isSome() and confidence.get() == 1:
if not awaiter1.finished:
awaiter1.complete()
check autonatService1.networkReachability == NetworkReachability.Unknown
autonatService1.statusAndConfidenceHandler(statusAndConfidenceHandler1)
await switch1.start()
await switch2.start()
await switch1.connect(switch2.peerInfo.peerId, switch2.peerInfo.addrs)
try:
# We allow a temp conn for the peer to dial us. It could use this conn to just connect to us and not dial.
# We don't care if it fails at this point or not. But this conn must be closed eventually.
# Below we check that there's only one connection between the peers
await switch2.connect(switch1.peerInfo.peerId, switch1.peerInfo.addrs, reuseConnection = false)
except CatchableError:
discard
await awaiter1
check autonatService1.networkReachability == NetworkReachability.Reachable
check libp2p_autonat_reachability_confidence.value(["Reachable"]) == 1
# Make sure remote peer can't create a connection to us
check switch1.connManager.connCount(switch2.peerInfo.peerId) == 1
await allFuturesThrowing(
switch1.stop(), switch2.stop())
asyncTest "Must work with low maxConnections":
let autonatService = AutonatService.new(AutonatClient.new(), newRng(), some(1.seconds), maxQueueSize = 1)
@@ -293,7 +387,7 @@ suite "Autonat Service":
if not awaiter.finished:
awaiter.complete()
check autonatService.networkReachability() == NetworkReachability.Unknown
check autonatService.networkReachability == NetworkReachability.Unknown
autonatService.statusAndConfidenceHandler(statusAndConfidenceHandler)
@@ -315,7 +409,7 @@ suite "Autonat Service":
await autonatService.run(switch1)
await awaiter
check autonatService.networkReachability() == NetworkReachability.Reachable
check autonatService.networkReachability == NetworkReachability.Reachable
check libp2p_autonat_reachability_confidence.value(["Reachable"]) == 1
await allFuturesThrowing(
@@ -330,7 +424,7 @@ suite "Autonat Service":
proc statusAndConfidenceHandler(networkReachability: NetworkReachability, confidence: Option[float]) {.gcsafe, async.} =
fail()
check autonatService.networkReachability() == NetworkReachability.Unknown
check autonatService.networkReachability == NetworkReachability.Unknown
autonatService.statusAndConfidenceHandler(statusAndConfidenceHandler)

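Beyond the test churn, the API shift here is that networkReachability is now read as a plain field instead of being called as a proc. Typical wiring, sketched from these tests:

let service = AutonatService.new(AutonatClient.new(), newRng(),
                                 some(500.millis), maxQueueSize = 3)

proc onStatus(networkReachability: NetworkReachability,
              confidence: Option[float]) {.gcsafe, async.} =
  if networkReachability == NetworkReachability.Reachable:
    echo "reachable, confidence: ", confidence

service.statusAndConfidenceHandler(onStatus)
let switch = createSwitch(service)   # the switch runs the service periodically
doAssert service.networkReachability == NetworkReachability.Unknown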

@@ -1,15 +1,21 @@
{.used.}
# Nim-Libp2p
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import chronos, options
import ../libp2p
import ../libp2p/[crypto/crypto,
protocols/connectivity/relay/relay,
protocols/connectivity/relay/messages,
protocols/connectivity/relay/utils,
protocols/connectivity/relay/client,
services/autorelayservice]
import ./helpers
import stew/byteutils
proc createSwitch(r: Relay, autorelay: Service = nil): Switch =
var builder = SwitchBuilder.new()
@@ -72,10 +78,13 @@ suite "Autorelay":
await fut.wait(1.seconds)
let addresses = autorelay.getAddresses()
check:
addresses[0] == buildRelayMA(switchRelay, switchClient)
addresses == @[buildRelayMA(switchRelay, switchClient)]
addresses.len() == 1
addresses == switchClient.peerInfo.addrs
await allFutures(switchClient.stop(), switchRelay.stop())
check addresses != switchClient.peerInfo.addrs
asyncTest "Three relays connections":
var state = 0
let


@@ -1,3 +1,14 @@
{.used.}
# Nim-Libp2p
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import chronos, stew/byteutils
import ../libp2p/stream/bufferstream,
../libp2p/stream/lpstream,
@@ -5,8 +16,6 @@ import ../libp2p/stream/bufferstream,
import ./helpers
{.used.}
suite "BufferStream":
teardown:
# echo getTracker(BufferStreamTrackerName).dump()


@@ -1,8 +1,17 @@
{.used.}
# Nim-Libp2p
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import unittest2
import ../libp2p/[cid, multihash, multicodec]
when defined(nimHasUsed): {.used.}
suite "Content identifier CID test suite":
test "CIDv0 test vector":


@@ -1,4 +1,15 @@
import chronos, nimcrypto/utils
{.used.}
# Nim-Libp2p
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import chronos
import ../libp2p/[stream/connection,
stream/bufferstream]


@@ -1,4 +1,15 @@
import sequtils
{.used.}
# Nim-Libp2p
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import std/[sequtils,tables]
import stew/results
import chronos
import ../libp2p/[connmanager,
@@ -10,8 +21,8 @@ import ../libp2p/[connmanager,
import helpers
proc getConnection(peerId: PeerId, dir: Direction = Direction.In): Connection =
return Connection.new(peerId, dir, Opt.none(MultiAddress))
proc getMuxer(peerId: PeerId, dir: Direction = Direction.In): Muxer =
return Muxer(connection: Connection.new(peerId, dir, Opt.none(MultiAddress)))
type
TestMuxer = ref object of Muxer
@@ -22,71 +33,68 @@ method newStream*(
name: string = "",
lazy: bool = false):
Future[Connection] {.async, gcsafe.} =
result = getConnection(m.peerId, Direction.Out)
result = Connection.new(m.peerId, Direction.Out, Opt.none(MultiAddress))
suite "Connection Manager":
teardown:
checkTrackers()
asyncTest "add and retrieve a connection":
asyncTest "add and retrieve a muxer":
let connMngr = ConnManager.new()
let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet()
let conn = getConnection(peerId)
let mux = getMuxer(peerId)
connMngr.storeConn(conn)
check conn in connMngr
connMngr.storeMuxer(mux)
check mux in connMngr
let peerConn = connMngr.selectConn(peerId)
check peerConn == conn
check peerConn.dir == Direction.In
let peerMux = connMngr.selectMuxer(peerId)
check peerMux == mux
check peerMux.connection.dir == Direction.In
await connMngr.close()
asyncTest "get all connections":
let connMngr = ConnManager.new()
let peers = toSeq(0..<2).mapIt(PeerId.random.tryGet())
let muxs = toSeq(0..<2).mapIt(getMuxer(peers[it]))
for mux in muxs: connMngr.storeMuxer(mux)
let conns = connMngr.getConnections()
let connsMux = toSeq(conns.values).mapIt(it[0])
check unorderedCompare(connsMux, muxs)
await connMngr.close()
asyncTest "shouldn't allow a closed connection":
let connMngr = ConnManager.new()
let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet()
let conn = getConnection(peerId)
await conn.close()
let mux = getMuxer(peerId)
await mux.connection.close()
expect CatchableError:
connMngr.storeConn(conn)
connMngr.storeMuxer(mux)
await connMngr.close()
asyncTest "shouldn't allow an EOFed connection":
let connMngr = ConnManager.new()
let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet()
let conn = getConnection(peerId)
conn.isEof = true
let mux = getMuxer(peerId)
mux.connection.isEof = true
expect CatchableError:
connMngr.storeConn(conn)
connMngr.storeMuxer(mux)
await conn.close()
await mux.close()
await connMngr.close()
asyncTest "add and retrieve a muxer":
asyncTest "shouldn't allow a muxer with no connection":
let connMngr = ConnManager.new()
let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet()
let conn = getConnection(peerId)
let muxer = new Muxer
muxer.connection = conn
connMngr.storeConn(conn)
connMngr.storeMuxer(muxer)
check muxer in connMngr
let peerMuxer = connMngr.selectMuxer(conn)
check peerMuxer == muxer
await connMngr.close()
asyncTest "shouldn't allow a muxer for an untracked connection":
let connMngr = ConnManager.new()
let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet()
let conn = getConnection(peerId)
let muxer = new Muxer
muxer.connection = conn
let muxer = getMuxer(peerId)
let conn = muxer.connection
muxer.connection = nil
expect CatchableError:
connMngr.storeMuxer(muxer)
@@ -99,33 +107,34 @@ suite "Connection Manager":
# This would work with 1 as well because of a bug in connmanager that will get fixed soon
let connMngr = ConnManager.new(maxConnsPerPeer = 2)
let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet()
let conn1 = getConnection(peerId, Direction.Out)
let conn2 = getConnection(peerId)
let mux1 = getMuxer(peerId, Direction.Out)
let mux2 = getMuxer(peerId)
connMngr.storeConn(conn1)
connMngr.storeConn(conn2)
check conn1 in connMngr
check conn2 in connMngr
connMngr.storeMuxer(mux1)
connMngr.storeMuxer(mux2)
check mux1 in connMngr
check mux2 in connMngr
let outConn = connMngr.selectConn(peerId, Direction.Out)
let inConn = connMngr.selectConn(peerId, Direction.In)
let outMux = connMngr.selectMuxer(peerId, Direction.Out)
let inMux = connMngr.selectMuxer(peerId, Direction.In)
check outConn != inConn
check outConn.dir == Direction.Out
check inConn.dir == Direction.In
check outMux != inMux
check outMux == mux1
check inMux == mux2
check outMux.connection.dir == Direction.Out
check inMux.connection.dir == Direction.In
await connMngr.close()
asyncTest "get muxed stream for peer":
let connMngr = ConnManager.new()
let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet()
let conn = getConnection(peerId)
let muxer = new TestMuxer
let connection = Connection.new(peerId, Direction.In, Opt.none(MultiAddress))
muxer.peerId = peerId
muxer.connection = conn
muxer.connection = connection
connMngr.storeConn(conn)
connMngr.storeMuxer(muxer)
check muxer in connMngr
@@ -134,18 +143,18 @@ suite "Connection Manager":
check stream.peerId == peerId
await connMngr.close()
await connection.close()
await stream.close()
asyncTest "get stream from directed connection":
let connMngr = ConnManager.new()
let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet()
let conn = getConnection(peerId)
let muxer = new TestMuxer
let connection = Connection.new(peerId, Direction.In, Opt.none(MultiAddress))
muxer.peerId = peerId
muxer.connection = conn
muxer.connection = connection
connMngr.storeConn(conn)
connMngr.storeMuxer(muxer)
check muxer in connMngr
@@ -156,99 +165,73 @@ suite "Connection Manager":
await connMngr.close()
await stream1.close()
asyncTest "get stream from any connection":
let connMngr = ConnManager.new()
let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet()
let conn = getConnection(peerId)
let muxer = new TestMuxer
muxer.peerId = peerId
muxer.connection = conn
connMngr.storeConn(conn)
connMngr.storeMuxer(muxer)
check muxer in connMngr
let stream = await connMngr.getStream(conn)
check not(isNil(stream))
await connMngr.close()
await stream.close()
await connection.close()
asyncTest "should raise on too many connections":
let connMngr = ConnManager.new(maxConnsPerPeer = 0)
let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet()
connMngr.storeConn(getConnection(peerId))
connMngr.storeMuxer(getMuxer(peerId))
let conns = @[
getConnection(peerId),
getConnection(peerId)]
let muxs = @[getMuxer(peerId)]
expect TooManyConnectionsError:
connMngr.storeConn(conns[0])
connMngr.storeMuxer(muxs[0])
await connMngr.close()
await allFuturesThrowing(
allFutures(conns.mapIt( it.close() )))
allFutures(muxs.mapIt( it.close() )))
asyncTest "expect connection from peer":
# FIXME This should be 1 instead of 0; it will get fixed soon
let connMngr = ConnManager.new(maxConnsPerPeer = 0)
let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet()
connMngr.storeConn(getConnection(peerId))
connMngr.storeMuxer(getMuxer(peerId))
let conns = @[
getConnection(peerId),
getConnection(peerId)]
let muxs = @[
getMuxer(peerId),
getMuxer(peerId)]
expect TooManyConnectionsError:
connMngr.storeConn(conns[0])
connMngr.storeMuxer(muxs[0])
let waitedConn1 = connMngr.expectConnection(peerId)
let waitedConn1 = connMngr.expectConnection(peerId, In)
expect LPError:
discard await connMngr.expectConnection(peerId)
expect AlreadyExpectingConnectionError:
discard await connMngr.expectConnection(peerId, In)
await waitedConn1.cancelAndWait()
let
waitedConn2 = connMngr.expectConnection(peerId)
waitedConn3 = connMngr.expectConnection(PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet())
conn = getConnection(peerId)
connMngr.storeConn(conn)
waitedConn2 = connMngr.expectConnection(peerId, In)
waitedConn3 = connMngr.expectConnection(PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), In)
conn = getMuxer(peerId)
connMngr.storeMuxer(conn)
check (await waitedConn2) == conn
expect TooManyConnectionsError:
connMngr.storeConn(conns[1])
connMngr.storeMuxer(muxs[1])
await connMngr.close()
checkExpiring: waitedConn3.cancelled()
await allFuturesThrowing(
allFutures(conns.mapIt( it.close() )))
allFutures(muxs.mapIt( it.close() )))
asyncTest "cleanup on connection close":
let connMngr = ConnManager.new()
let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet()
let conn = getConnection(peerId)
let muxer = new Muxer
let muxer = getMuxer(peerId)
muxer.connection = conn
connMngr.storeConn(conn)
connMngr.storeMuxer(muxer)
check conn in connMngr
check muxer in connMngr
await conn.close()
await sleepAsync(10.millis)
await muxer.close()
check conn notin connMngr
check muxer notin connMngr
checkExpiring: muxer notin connMngr
await connMngr.close()
@@ -261,23 +244,19 @@ suite "Connection Manager":
Direction.In else:
Direction.Out
let conn = getConnection(peerId, dir)
let muxer = new Muxer
muxer.connection = conn
let muxer = getMuxer(peerId, dir)
connMngr.storeConn(conn)
connMngr.storeMuxer(muxer)
check conn in connMngr
check muxer in connMngr
check not(isNil(connMngr.selectConn(peerId, dir)))
check not(isNil(connMngr.selectMuxer(peerId, dir)))
check peerId in connMngr
await connMngr.dropPeer(peerId)
check peerId notin connMngr
check isNil(connMngr.selectConn(peerId, Direction.In))
check isNil(connMngr.selectConn(peerId, Direction.Out))
checkExpiring: peerId notin connMngr
check isNil(connMngr.selectMuxer(peerId, Direction.In))
check isNil(connMngr.selectMuxer(peerId, Direction.Out))
await connMngr.close()
@@ -363,7 +342,6 @@ suite "Connection Manager":
asyncTest "track incoming max connections limits - fail on outgoing":
let connMngr = ConnManager.new(maxIn = 3)
var conns: seq[Connection]
for i in 0..<3:
check await connMngr.getIncomingSlot().withTimeout(10.millis)
@@ -376,7 +354,6 @@ suite "Connection Manager":
asyncTest "allow force dial":
let connMngr = ConnManager.new(maxConnections = 2)
var conns: seq[Connection]
for i in 0..<3:
discard connMngr.getOutgoingSlot(true)
@@ -389,17 +366,17 @@ suite "Connection Manager":
asyncTest "release slot on connection end":
let connMngr = ConnManager.new(maxConnections = 3)
var conns: seq[Connection]
var muxs: seq[Muxer]
for i in 0..<3:
let slot = connMngr.getOutgoingSlot()
let conn =
getConnection(
let muxer =
getMuxer(
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
Direction.In)
slot.trackConnection(conn)
conns.add(conn)
slot.trackMuxer(muxer)
muxs.add(muxer)
# should be full now
let incomingSlot = connMngr.getIncomingSlot()
@@ -407,7 +384,7 @@ suite "Connection Manager":
check (await incomingSlot.withTimeout(10.millis)) == false
await allFuturesThrowing(
allFutures(conns.mapIt( it.close() )))
allFutures(muxs.mapIt( it.close() )))
check await incomingSlot.withTimeout(10.millis)


@@ -1,3 +1,5 @@
{.used.}
# Nim-Libp2p
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
@@ -11,11 +13,9 @@
## https://github.com/libp2p/go-libp2p-crypto/blob/master/key.go
import unittest2
import bearssl/hash
import nimcrypto/[utils, sysrand]
import nimcrypto/utils
import ../libp2p/crypto/[crypto, chacha20poly1305, curve25519, hkdf]
when defined(nimHasUsed): {.used.}
const
PrivateKeys = [
"""080012A809308204A40201000282010100C8B014EC01E135D635F7E246BA7D42

157
tests/testdcutr.nim Normal file

@@ -0,0 +1,157 @@
{.used.}
# Nim-Libp2p
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import chronos
import unittest2
import ../libp2p/protocols/connectivity/dcutr/core as dcore
import ../libp2p/protocols/connectivity/dcutr/[client, server]
from ../libp2p/protocols/connectivity/autonat/core import NetworkReachability
import ../libp2p/builders
import ../libp2p/utils/future
import ./helpers
import ./stubs/switchstub
suite "Dcutr":
teardown:
checkTrackers()
asyncTest "Connect msg Encode / Decode":
let addrs = @[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet(), MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()]
let connectMsg = DcutrMsg(msgType: MsgType.Connect, addrs: addrs)
let pb = connectMsg.encode()
let connectMsgDecoded = DcutrMsg.decode(pb.buffer)
check connectMsg == connectMsgDecoded
asyncTest "Sync msg Encode / Decode":
let addrs = @[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet(), MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()]
let syncMsg = DcutrMsg(msgType: MsgType.Sync, addrs: addrs)
let pb = syncMsg.encode()
let syncMsgDecoded = DcutrMsg.decode(pb.buffer)
check syncMsg == syncMsgDecoded
asyncTest "DCUtR establishes a new connection":
let behindNATSwitch = newStandardSwitch()
let publicSwitch = newStandardSwitch()
let dcutrProto = Dcutr.new(publicSwitch)
publicSwitch.mount(dcutrProto)
await allFutures(behindNATSwitch.start(), publicSwitch.start())
await publicSwitch.connect(behindNATSwitch.peerInfo.peerId, behindNATSwitch.peerInfo.addrs)
for t in behindNATSwitch.transports:
t.networkReachability = NetworkReachability.NotReachable
await DcutrClient.new().startSync(behindNATSwitch, publicSwitch.peerInfo.peerId, behindNATSwitch.peerInfo.addrs)
.wait(300.millis)
checkExpiring:
# we can't hole punch when both peers are on the same machine. This means that the simultaneous dialings will result
# in two connection attempts instead of one. The server dial is going to fail because it acts as the
# TCP simultaneous-open incoming upgrader in the dialer, which only works in the genuinely simultaneous case, but
# the client dial will succeed.
behindNATSwitch.connManager.connCount(publicSwitch.peerInfo.peerId) == 2
await allFutures(behindNATSwitch.stop(), publicSwitch.stop())
template ductrClientTest(behindNATSwitch: Switch, publicSwitch: Switch, body: untyped) =
let dcutrProto = Dcutr.new(publicSwitch)
publicSwitch.mount(dcutrProto)
await allFutures(behindNATSwitch.start(), publicSwitch.start())
await publicSwitch.connect(behindNATSwitch.peerInfo.peerId, behindNATSwitch.peerInfo.addrs)
for t in behindNATSwitch.transports:
t.networkReachability = NetworkReachability.NotReachable
body
checkExpiring:
# no connection will be open by the receiver peer acting as the dcutr server
behindNATSwitch.connManager.connCount(publicSwitch.peerInfo.peerId) == 1
await allFutures(behindNATSwitch.stop(), publicSwitch.stop())
asyncTest "Client connect timeout":
proc connectTimeoutProc(): Future[void] {.async.} =
await sleepAsync(100.millis)
let behindNATSwitch = SwitchStub.new(newStandardSwitch(), connectTimeoutProc)
let publicSwitch = newStandardSwitch()
ductrClientTest(behindNATSwitch, publicSwitch):
try:
let client = DcutrClient.new(connectTimeout = 5.millis)
await client.startSync(behindNATSwitch, publicSwitch.peerInfo.peerId, behindNATSwitch.peerInfo.addrs)
except DcutrError as err:
check err.parent of AsyncTimeoutError
asyncTest "All client connect attempts fail":
proc connectErrorProc(): Future[void] {.async.} =
raise newException(CatchableError, "error")
let behindNATSwitch = SwitchStub.new(newStandardSwitch(), connectErrorProc)
let publicSwitch = newStandardSwitch()
ductrClientTest(behindNATSwitch, publicSwitch):
try:
let client = DcutrClient.new(connectTimeout = 5.millis)
await client.startSync(behindNATSwitch, publicSwitch.peerInfo.peerId, behindNATSwitch.peerInfo.addrs)
except DcutrError as err:
check err.parent of AllFuturesFailedError
proc ductrServerTest(connectStub: proc (): Future[void] {.async.}) {.async.} =
let behindNATSwitch = newStandardSwitch()
let publicSwitch = SwitchStub.new(newStandardSwitch())
let dcutrProto = Dcutr.new(publicSwitch, connectTimeout = 5.millis)
publicSwitch.mount(dcutrProto)
await allFutures(behindNATSwitch.start(), publicSwitch.start())
await publicSwitch.connect(behindNATSwitch.peerInfo.peerId, behindNATSwitch.peerInfo.addrs)
publicSwitch.connectStub = connectStub
for t in behindNATSwitch.transports:
t.networkReachability = NetworkReachability.NotReachable
await DcutrClient.new().startSync(behindNATSwitch, publicSwitch.peerInfo.peerId, behindNATSwitch.peerInfo.addrs)
.wait(300.millis)
checkExpiring:
# we can't hole punch when both peers are on the same machine. This means that the simultaneous dialings will result
# in two connection attempts instead of one. The server dial is going to fail, but the client dial will succeed.
behindNATSwitch.connManager.connCount(publicSwitch.peerInfo.peerId) == 2
await allFutures(behindNATSwitch.stop(), publicSwitch.stop())
asyncTest "DCUtR server timeout when establishing a new connection":
proc connectProc(): Future[void] {.async.} =
await sleepAsync(100.millis)
await ductrServerTest(connectProc)
asyncTest "DCUtR server error when establishing a new connection":
proc connectProc(): Future[void] {.async.} =
raise newException(CatchableError, "error")
await ductrServerTest(connectProc)

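For orientation, the round trip these tests exercise (simplified; only the client measures timing):

# 1. B (behind NAT, dcutr client) opens a stream to A and sends
#    Connect{addrsB}, noting the send time t0.
# 2. A (dcutr server) replies Connect{addrsA}; B computes rtt = now() - t0.
# 3. B sends Sync, waits about rtt/2, then dials addrsA; A dials addrsB as
#    soon as Sync arrives, so the two dials start at roughly the same time.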

@@ -1,7 +1,15 @@
{.used.}
# Nim-Libp2p
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import options, chronos, sets
import stew/byteutils
import ../libp2p/[protocols/rendezvous,
switch,
builders,


@@ -1,3 +1,5 @@
{.used.}
# Nim-Libp2p
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
@@ -6,13 +8,12 @@
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import unittest2
import nimcrypto/utils
import ../libp2p/crypto/[crypto, ecnist]
import stew/results
when defined(nimHasUsed): {.used.}
const
TestsCount = 10 # number of random tests


@@ -1,3 +1,5 @@
{.used.}
# Nim-Libp2p
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
@@ -13,8 +15,6 @@ import nimcrypto/utils
import ../libp2p/crypto/crypto
import ../libp2p/crypto/ed25519/ed25519
when defined(nimHasUsed): {.used.}
const TestsCount = 20
const SecretKeys = [


@@ -1,3 +1,5 @@
{.used.}
# Nim-Libp2p
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of


@@ -1,11 +1,22 @@
import chronos
{.used.}
import ../libp2p/utils/heartbeat
import ./helpers
# Nim-Libp2p
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
# macOS has some nasty jitter when sleeping
# (up to 7 ms), so we skip the test there
when not defined(macosx):
import chronos
import ../libp2p/utils/heartbeat
import ./helpers
suite "Heartbeat":
asyncTest "simple heartbeat":

219
tests/testhpservice.nim Normal file

@@ -0,0 +1,219 @@
# Nim-LibP2P
# Copyright (c) 2022 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
{.used.}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import chronos
import unittest2
import ./helpers
import ./stubs/switchstub
import ../libp2p/[builders,
switch,
services/hpservice,
services/autorelayservice]
import ../libp2p/protocols/connectivity/relay/[relay, client]
import ../libp2p/protocols/connectivity/autonat/[service]
import ../libp2p/nameresolving/nameresolver
import ../libp2p/nameresolving/mockresolver
import stubs/autonatclientstub
proc isPublicAddrIPAddrMock(ta: TransportAddress): bool =
return true
proc createSwitch(r: Relay = nil, hpService: Service = nil, nameResolver: NameResolver = nil): Switch {.raises: [LPError, Defect].} =
var builder = SwitchBuilder.new()
.withRng(newRng())
.withAddresses(@[ MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() ])
.withTcpTransport()
.withMplex()
.withAutonat()
.withNoise()
if hpService != nil:
builder = builder.withServices(@[hpService])
if r != nil:
builder = builder.withCircuitRelay(r)
if nameResolver != nil:
builder = builder.withNameResolver(nameResolver)
return builder.build()
proc buildRelayMA(switchRelay: Switch, switchClient: Switch): MultiAddress =
MultiAddress.init($switchRelay.peerInfo.addrs[0] & "/p2p/" &
$switchRelay.peerInfo.peerId & "/p2p-circuit/p2p/" &
$switchClient.peerInfo.peerId).get()
suite "Hole Punching":
teardown:
checkTrackers()
asyncTest "Direct connection must work when peer address is public":
let autonatClientStub = AutonatClientStub.new(expectedDials = 1)
autonatClientStub.answer = NotReachable
let autonatService = AutonatService.new(autonatClientStub, newRng(), maxQueueSize = 1)
let relayClient = RelayClient.new()
let privatePeerRelayAddr = newFuture[seq[MultiAddress]]()
let publicPeerSwitch = createSwitch(RelayClient.new())
proc checkMA(address: seq[MultiAddress]) =
if not privatePeerRelayAddr.completed():
privatePeerRelayAddr.complete(address)
let autoRelayService = AutoRelayService.new(1, relayClient, checkMA, newRng())
let hpservice = HPService.new(autonatService, autoRelayService, isPublicAddrIPAddrMock)
let privatePeerSwitch = createSwitch(relayClient, hpservice)
let peerSwitch = createSwitch()
let switchRelay = createSwitch(Relay.new())
await allFutures(switchRelay.start(), privatePeerSwitch.start(), publicPeerSwitch.start(), peerSwitch.start())
await privatePeerSwitch.connect(switchRelay.peerInfo.peerId, switchRelay.peerInfo.addrs)
await privatePeerSwitch.connect(peerSwitch.peerInfo.peerId, peerSwitch.peerInfo.addrs) # for autonat
await publicPeerSwitch.connect(privatePeerSwitch.peerInfo.peerId, (await privatePeerRelayAddr))
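# The initially relayed connection should eventually be replaced
# by a direct (non-relayed) one.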
checkExpiring:
privatePeerSwitch.connManager.connCount(publicPeerSwitch.peerInfo.peerId) == 1 and
not isRelayed(privatePeerSwitch.connManager.selectMuxer(publicPeerSwitch.peerInfo.peerId).connection)
await allFuturesThrowing(
privatePeerSwitch.stop(), publicPeerSwitch.stop(), switchRelay.stop(), peerSwitch.stop())
asyncTest "Direct connection must work when peer address is public and dns is used":
let autonatClientStub = AutonatClientStub.new(expectedDials = 1)
autonatClientStub.answer = NotReachable
let autonatService = AutonatService.new(autonatClientStub, newRng(), maxQueueSize = 1)
let relayClient = RelayClient.new()
let privatePeerRelayAddr = newFuture[seq[MultiAddress]]()
let resolver = MockResolver.new()
resolver.ipResponses[("localhost", false)] = @["127.0.0.1"]
resolver.ipResponses[("localhost", true)] = @["::1"]
let publicPeerSwitch = createSwitch(RelayClient.new(), nameResolver = resolver)
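# Swaps the IP component of the listen address for /dns4/localhost while
# keeping the TCP part, so dialing peers have to go through name resolution.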
proc addressMapper(listenAddrs: seq[MultiAddress]): Future[seq[MultiAddress]] {.gcsafe, async.} =
return @[MultiAddress.init("/dns4/localhost/").tryGet() & listenAddrs[0][1].tryGet()]
publicPeerSwitch.peerInfo.addressMappers.add(addressMapper)
await publicPeerSwitch.peerInfo.update()
proc checkMA(address: seq[MultiAddress]) =
if not privatePeerRelayAddr.completed():
privatePeerRelayAddr.complete(address)
let autoRelayService = AutoRelayService.new(1, relayClient, checkMA, newRng())
let hpservice = HPService.new(autonatService, autoRelayService, isPublicAddrIPAddrMock)
let privatePeerSwitch = createSwitch(relayClient, hpservice, nameResolver = resolver)
let peerSwitch = createSwitch()
let switchRelay = createSwitch(Relay.new())
await allFutures(switchRelay.start(), privatePeerSwitch.start(), publicPeerSwitch.start(), peerSwitch.start())
await privatePeerSwitch.connect(switchRelay.peerInfo.peerId, switchRelay.peerInfo.addrs)
await privatePeerSwitch.connect(peerSwitch.peerInfo.peerId, peerSwitch.peerInfo.addrs) # for autonat
await publicPeerSwitch.connect(privatePeerSwitch.peerInfo.peerId, (await privatePeerRelayAddr))
checkExpiring:
privatePeerSwitch.connManager.connCount(publicPeerSwitch.peerInfo.peerId) == 1 and
not isRelayed(privatePeerSwitch.connManager.selectMuxer(publicPeerSwitch.peerInfo.peerId).connection)
await allFuturesThrowing(
privatePeerSwitch.stop(), publicPeerSwitch.stop(), switchRelay.stop(), peerSwitch.stop())
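# Shared scenario: two private peers behind stubbed autonat clients, one
# relay, and several auxiliary switches giving the autonat services peers
# to query.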
proc holePunchingTest(connectStub: proc (): Future[void] {.async.},
isPublicIPAddrProc: IsPublicIPAddrProc,
answer: Answer) {.async.} =
# There's no check in this test because hole punching can't be tested locally.
# It exists just to make sure the rest of the code works properly.
let autonatClientStub1 = AutonatClientStub.new(expectedDials = 1)
autonatClientStub1.answer = NotReachable
let autonatService1 = AutonatService.new(autonatClientStub1, newRng(), maxQueueSize = 1)
let autonatClientStub2 = AutonatClientStub.new(expectedDials = 1)
autonatClientStub2.answer = answer
let autonatService2 = AutonatService.new(autonatClientStub2, newRng(), maxQueueSize = 1)
let relayClient1 = RelayClient.new()
let relayClient2 = RelayClient.new()
let privatePeerRelayAddr1 = newFuture[seq[MultiAddress]]()
proc checkMA(address: seq[MultiAddress]) =
if not privatePeerRelayAddr1.completed():
privatePeerRelayAddr1.complete(address)
let autoRelayService1 = AutoRelayService.new(1, relayClient1, checkMA, newRng())
let autoRelayService2 = AutoRelayService.new(1, relayClient2, nil, newRng())
let hpservice1 = HPService.new(autonatService1, autoRelayService1, isPublicIPAddrProc)
let hpservice2 = HPService.new(autonatService2, autoRelayService2)
let privatePeerSwitch1 = SwitchStub.new(createSwitch(relayClient1, hpservice1))
let privatePeerSwitch2 = createSwitch(relayClient2, hpservice2)
let switchRelay = createSwitch(Relay.new())
let switchAux = createSwitch()
let switchAux2 = createSwitch()
let switchAux3 = createSwitch()
let switchAux4 = createSwitch()
var awaiter = newFuture[void]()
await allFutures(
switchRelay.start(), privatePeerSwitch1.start(), privatePeerSwitch2.start(),
switchAux.start(), switchAux2.start(), switchAux3.start(), switchAux4.start()
)
await privatePeerSwitch1.connect(switchRelay.peerInfo.peerId, switchRelay.peerInfo.addrs)
await privatePeerSwitch2.connect(switchAux.peerInfo.peerId, switchAux.peerInfo.addrs)
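# Brief pause, presumably to let the first connections settle before
# dialing the remaining peers.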
await sleepAsync(200.millis)
await privatePeerSwitch1.connect(switchAux2.peerInfo.peerId, switchAux2.peerInfo.addrs)
await privatePeerSwitch1.connect(switchAux3.peerInfo.peerId, switchAux3.peerInfo.addrs)
await privatePeerSwitch1.connect(switchAux4.peerInfo.peerId, switchAux4.peerInfo.addrs)
await privatePeerSwitch2.connect(switchAux2.peerInfo.peerId, switchAux2.peerInfo.addrs)
await privatePeerSwitch2.connect(switchAux3.peerInfo.peerId, switchAux3.peerInfo.addrs)
await privatePeerSwitch2.connect(switchAux4.peerInfo.peerId, switchAux4.peerInfo.addrs)
privatePeerSwitch1.connectStub = connectStub
await privatePeerSwitch2.connect(privatePeerSwitch1.peerInfo.peerId, (await privatePeerRelayAddr1))
await sleepAsync(200.millis)
await allFuturesThrowing(
privatePeerSwitch1.stop(), privatePeerSwitch2.stop(), switchRelay.stop(),
switchAux.stop(), switchAux2.stop(), switchAux3.stop(), switchAux4.stop())
asyncTest "Hole punching when peers addresses are private":
await holePunchingTest(nil, isGlobal, NotReachable)
asyncTest "Hole punching when there is an error during unilateral direct connection":
proc connectStub(): Future[void] {.async.} =
raise newException(CatchableError, "error")
await holePunchingTest(connectStub, isPublicAddrIPAddrMock, Reachable)

Some files were not shown because too many files have changed in this diff.