Compare commits

...

43 Commits

Author SHA1 Message Date
vladopajic
8b0565725f Merge branch 'master' into no-splitting 2025-09-08 16:53:32 +02:00
Gabriel Cruz
f345026900 fix(linters): use workaround for reviewdog bug (#1668) 2025-09-08 14:48:03 +00:00
Vlado Pajić
4780af2036 splitRPCMsg removed 2025-09-08 16:37:25 +02:00
vladopajic
5d6578a06f chore: splitRPCMsg improvements (#1665) 2025-09-08 11:06:55 -03:00
Gabriel Cruz
871a5d047f feat(autonat-v2): add server (#1658) 2025-09-04 13:27:49 -04:00
Gabriel Cruz
061195195b chore(autonat-v2): add utils (#1657) 2025-09-03 19:04:46 +00:00
Radosław Kamiński
8add5aaaab fix(rendezvous): peer registration limit (#1656) 2025-09-03 18:01:23 +01:00
Miran
dbf60b74c7 chore(ci): remove macos-13 from the matrix (#1650) 2025-09-03 11:16:37 -04:00
Radosław Kamiński
d2eaf07960 test(rendezvous): Registration TTL tests (#1655) 2025-09-02 15:43:48 +01:00
Gabriel Cruz
6e5274487e chore: pass localAddr in noise, mplex and yamux (#1654) 2025-09-01 23:38:23 +02:00
Gabriel Cruz
7ed62461d7 chore: add localAddr to Connection (#1651) 2025-09-01 20:39:08 +02:00
Radosław Kamiński
6059ee8332 test(performance): upload plots as artifacts (#1648) 2025-09-01 16:12:49 +00:00
Radosław Kamiński
4f7e232a9e fix(rendezvous): pagination offset (#1646) 2025-08-29 18:27:03 +01:00
richΛrd
5eaa43b860 fix: dont send GoAway for unknown streams and mark streams as closed on conn close (#1645) 2025-08-28 09:34:45 -04:00
richΛrd
17ed2d88df chore: temporarily disable performance plots from being published (#1647) 2025-08-28 08:20:12 -04:00
Radosław Kamiński
c7f29ed5db test(rendezvous): Refactor Rendezvous tests (#1644) 2025-08-28 09:35:04 +01:00
vladopajic
9865cc39b5 chore(perf): follow up for PR#1600 (#1620) 2025-08-26 10:00:25 -04:00
Gabriel Cruz
601f56b786 chore(autonat-v2): add message types (#1637) 2025-08-25 15:18:43 +00:00
Ben
25a8ed4d07 refactor(kad): Refine, and reduce, exception scope (#1627) 2025-08-25 11:33:26 +00:00
Radosław Kamiński
955e28ff70 test(yamux): Add unit tests - frame handling and stream initiation (#1634) 2025-08-22 12:02:54 +01:00
Radosław Kamiński
f952e6d436 test(performance): do not run publish steps on forks and fix cleanup (#1630) 2025-08-19 13:25:52 +01:00
MorganaFuture
bed83880bf fix(test): Race condition on Windows-specific daemon close (#1628)
Co-authored-by: Ben <benph@vac.dev>
Co-authored-by: vladopajic <vladopajic@users.noreply.github.com>
2025-08-18 17:09:31 -04:00
richΛrd
9bd4b7393f feat(kad-dht): findPeer (#1624) 2025-08-18 13:45:31 +00:00
Radosław Kamiński
12d1fae404 test(yamux): Add header unit tests (#1625) 2025-08-18 13:50:54 +01:00
MorganaFuture
17073dc9e0 fix(tests): prevent race condition in testgossipsubcontrolmessages (#1626) 2025-08-15 18:46:39 +00:00
vladopajic
b1649b3566 chore(quic): add length prefixed test (#1599) 2025-08-15 15:57:56 +02:00
Ben
ef20f46b47 refactor: rm dhttypes.nim (#1612) 2025-08-15 12:23:27 +00:00
Gabriel Cruz
9161529c84 fix: pubsub signature verification (#1618) 2025-08-14 20:15:02 +00:00
Ben
8b70384b6a refactor: Removal of "Unhashed" key variant (#1623)
Internal keydata is _always_ unhashed. The parts that require its data in hashed form hash it themselves using the provided hasher (with default fallback)
2025-08-14 11:22:09 +00:00
MorganaFuture
f25814a890 feat(perf): implement proper half-close semantics (#1600)
Co-authored-by: vladopajic <vladopajic@users.noreply.github.com>
2025-08-13 10:08:17 -04:00
Radosław Kamiński
3d5ea1fa3c test(performance): fetch before push and improve latency history (#1617) 2025-08-13 14:22:42 +01:00
richΛrd
2114008704 fix: compilation warning on yamux due to using CatchableErr (#1616) 2025-08-12 22:11:33 +00:00
richΛrd
04796b210b fix: don't check for errors as close() will only contain futures that raise [] (#1615) 2025-08-12 21:26:22 +00:00
Ben
59faa023aa feat(kad): Initial unstable putval api (#1582) 2025-08-12 12:25:21 +02:00
vladopajic
fdebea4e14 chore(quic): fix flaky test when eof is expected (#1611) 2025-08-11 17:02:13 +00:00
vladopajic
0c188df806 fix(quic): race errors when stopping transport (#1614) 2025-08-11 15:48:37 +00:00
Radosław Kamiński
abee5326dc test(gossipsub): Performance tests - plot latency history (#1608) 2025-08-11 16:11:29 +01:00
Radosław Kamiński
71f04d1bb3 test(gossipsub): Performance tests - plot docker stats (#1597) 2025-08-11 15:45:50 +01:00
Radosław Kamiński
41ae43ae80 test(gossipsub): Performance tests - collect docker stats (#1593) 2025-08-11 14:01:38 +00:00
vladopajic
5dbf077d9e chore(pubsub): simplify prune backoff test (#1596) 2025-08-09 17:49:14 +00:00
vladopajic
b5fc7582ff fix(quic): setting shortAgent (#1609) 2025-08-08 17:21:58 +00:00
vladopajic
7f83ebb198 chore(quic): readOnce better exception handling (#1610) 2025-08-08 16:02:33 +00:00
vladopajic
ceb89986c1 chore(quic): exception msg fix (#1607) 2025-08-08 10:24:55 -03:00
85 changed files with 3718 additions and 648 deletions

View File

@@ -1,12 +1,5 @@
name: Add Comment
description: "Add or update comment in the PR"
inputs:
marker:
description: "Text used to find the comment to update"
required: true
markdown_path:
description: "Path to the file containing markdown"
required: true
runs:
using: "composite"
@@ -16,8 +9,8 @@ runs:
with:
script: |
const fs = require('fs');
const marker = "${{ inputs.marker }}";
const body = fs.readFileSync("${{ inputs.markdown_path }}", 'utf8');
const marker = "${{ env.MARKER }}";
const body = fs.readFileSync("${{ env.COMMENT_SUMMARY_PATH }}", 'utf8');
const { data: comments } = await github.rest.issues.listComments({
owner: context.repo.owner,
repo: context.repo.repo,

View File

@@ -0,0 +1,24 @@
name: Generate Plots
description: "Set up Python and run script to generate plots with Docker Stats"
runs:
using: "composite"
steps:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Install Python dependencies
shell: bash
run: |
python -m pip install --upgrade pip
pip install matplotlib
- name: Plot Docker Stats
shell: bash
run: python performance/scripts/plot_docker_stats.py
- name: Plot Latency History
shell: bash
run: python performance/scripts/plot_latency_history.py

View File

@@ -0,0 +1,21 @@
name: Process Stats
description: "Set up Nim and run scripts to aggregate latency and process raw docker stats"
runs:
using: "composite"
steps:
- name: Set up Nim
uses: jiro4989/setup-nim-action@v2
with:
nim-version: "2.x"
repo-token: ${{ env.GITHUB_TOKEN }}
- name: Aggregate latency stats and prepare markdown for comment and summary
shell: bash
run: |
nim c -r -d:release -o:/tmp/process_latency_stats ./performance/scripts/process_latency_stats.nim
- name: Process raw docker stats to csv files
shell: bash
run: |
nim c -r -d:release -o:/tmp/process_docker_stats ./performance/scripts/process_docker_stats.nim

View File

@@ -0,0 +1,36 @@
name: Publish Latency History
description: "Publish latency history CSVs in a configurable branch and folder"
runs:
using: "composite"
steps:
- name: Clone the branch
uses: actions/checkout@v4
with:
repository: ${{ github.repository }}
ref: ${{ env.PUBLISH_BRANCH_NAME }}
path: ${{ env.CHECKOUT_SUBFOLDER_HISTORY }}
fetch-depth: 0
- name: Commit & push latency history CSVs
shell: bash
run: |
cd "$CHECKOUT_SUBFOLDER_HISTORY"
git fetch origin "$PUBLISH_BRANCH_NAME"
git reset --hard "origin/$PUBLISH_BRANCH_NAME"
mkdir -p "$PUBLISH_DIR_LATENCY_HISTORY"
cp ../$SHARED_VOLUME_PATH/$LATENCY_HISTORY_PREFIX*.csv "$PUBLISH_DIR_LATENCY_HISTORY/"
git add "$PUBLISH_DIR_LATENCY_HISTORY"
if git diff-index --quiet HEAD --; then
echo "No changes to commit"
else
git config user.email "github-actions[bot]@users.noreply.github.com"
git config user.name "github-actions[bot]"
git commit -m "Update latency history CSVs"
git push origin "$PUBLISH_BRANCH_NAME"
fi
cd ..

View File

@@ -0,0 +1,56 @@
name: Publish Plots
description: "Publish plots in performance_plots branch and add to the workflow summary"
runs:
using: "composite"
steps:
- name: Clone the performance_plots branch
uses: actions/checkout@v4
with:
repository: ${{ github.repository }}
ref: ${{ env.PUBLISH_BRANCH_NAME }}
path: ${{ env.CHECKOUT_SUBFOLDER_SUBPLOTS }}
fetch-depth: 0
- name: Commit & push plots
shell: bash
run: |
cd $CHECKOUT_SUBFOLDER_SUBPLOTS
git fetch origin "$PUBLISH_BRANCH_NAME"
git reset --hard "origin/$PUBLISH_BRANCH_NAME"
# Remove any branch folder older than 7 days
DAYS=7
cutoff=$(( $(date +%s) - DAYS*24*3600 ))
scan_dir="${PUBLISH_DIR_PLOTS%/}"
find "$scan_dir" -mindepth 1 -maxdepth 1 -type d -print0 \
| while IFS= read -r -d $'\0' d; do \
ts=$(git log -1 --format=%ct -- "$d" 2>/dev/null || true); \
if [ -n "$ts" ] && [ "$ts" -le "$cutoff" ]; then \
echo "[cleanup] Deleting: $d"; \
rm -rf -- "$d"; \
fi; \
done
rm -rf $PUBLISH_DIR_PLOTS/$BRANCH_NAME
mkdir -p $PUBLISH_DIR_PLOTS/$BRANCH_NAME
cp ../$SHARED_VOLUME_PATH/*.png $PUBLISH_DIR_PLOTS/$BRANCH_NAME/ 2>/dev/null || true
cp ../$LATENCY_HISTORY_PATH/*.png $PUBLISH_DIR_PLOTS/ 2>/dev/null || true
git add -A "$PUBLISH_DIR_PLOTS/"
git status
if git diff-index --quiet HEAD --; then
echo "No changes to commit"
else
git config user.email "github-actions[bot]@users.noreply.github.com"
git config user.name "github-actions[bot]"
git commit -m "Update performance plots for $BRANCH_NAME"
git push origin $PUBLISH_BRANCH_NAME
fi
- name: Add plots to GitHub Actions summary
shell: bash
run: |
nim c -r -d:release -o:/tmp/add_plots_to_summary ./performance/scripts/add_plots_to_summary.nim

View File

@@ -25,8 +25,6 @@ jobs:
cpu: i386
- os: linux-gcc-14
cpu: amd64
- os: macos
cpu: amd64
- os: macos-14
cpu: arm64
- os: windows
@@ -45,10 +43,6 @@ jobs:
os: linux-gcc-14
builder: ubuntu-24.04
shell: bash
- platform:
os: macos
builder: macos-13
shell: bash
- platform:
os: macos-14
builder: macos-14

View File

@@ -19,7 +19,7 @@ jobs:
fetch-depth: 2 # In PR, has extra merge commit: ^1 = PR, ^2 = base
- name: Check `nph` formatting
uses: arnetheduck/nph-action@v1
uses: arnetheduck/nph-action@ef5e9fae6dbaf88ec4308cbede780a0ba45f845d
with:
version: 0.6.1
options: "examples libp2p tests interop tools *.nim*"

View File

@@ -13,8 +13,8 @@ concurrency:
cancel-in-progress: true
jobs:
examples:
timeout-minutes: 10
performance:
timeout-minutes: 20
strategy:
fail-fast: false
@@ -22,6 +22,25 @@ jobs:
run:
shell: bash
env:
VACP2P: "vacp2p"
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_HEAD_SHA: ${{ github.event.pull_request.head.sha }}
PR_NUMBER: ${{ github.event.number }}
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
MARKER: "<!-- perf-summary-marker -->"
COMMENT_SUMMARY_PATH: "/tmp/perf-summary.md"
SHARED_VOLUME_PATH: "performance/output"
DOCKER_STATS_PREFIX: "docker_stats_"
PUBLISH_BRANCH_NAME: "performance_plots"
CHECKOUT_SUBFOLDER_SUBPLOTS: "subplots"
PUBLISH_DIR_PLOTS: "plots"
CHECKOUT_SUBFOLDER_HISTORY: "history"
PUBLISH_DIR_LATENCY_HISTORY: "latency_history"
LATENCY_HISTORY_PATH: "history/latency_history"
LATENCY_HISTORY_PREFIX: "pr"
LATENCY_HISTORY_PLOT_FILENAME: "latency_history_all_scenarios.png"
name: "Performance"
runs-on: ubuntu-22.04
steps:
@@ -47,23 +66,29 @@ jobs:
run: |
./performance/runner.sh
- name: Set up Nim for aggragate script
uses: jiro4989/setup-nim-action@v2
with:
nim-version: "2.x"
repo-token: ${{ secrets.GITHUB_TOKEN }}
- name: Process latency and docker stats
uses: ./.github/actions/process_stats
- name: Aggregate and display summary
env:
MARKER: "<!-- perf-summary-marker -->"
PR_HEAD_SHA: ${{ github.event.pull_request.head.sha }}
COMMENT_SUMMARY_PATH: "/tmp/perf-summary.md"
run: |
nim c -r -d:release -o:/tmp/aggregate_stats ./performance/aggregate_stats.nim
- name: Publish history
if: github.repository_owner == env.VACP2P
uses: ./.github/actions/publish_history
- name: Post/Update PR Performance Comment
- name: Generate plots
if: github.repository_owner == env.VACP2P
uses: ./.github/actions/generate_plots
- name: Post/Update PR comment
if: github.event_name == 'pull_request'
uses: ./.github/actions/add_comment
- name: Upload performance artifacts
if: success() || failure()
uses: actions/upload-artifact@v4
with:
marker: "<!-- perf-summary-marker -->"
markdown_path: "/tmp/perf-summary.md"
name: performance-artifacts
path: |
performance/output/pr*_latency.csv
performance/output/*.png
history/latency_history/*.png
if-no-files-found: ignore
retention-days: 7

View File

@@ -8,7 +8,7 @@ json_serialization;https://github.com/status-im/nim-json-serialization@#2b1c5eb1
metrics;https://github.com/status-im/nim-metrics@#6142e433fc8ea9b73379770a788017ac528d46ff
ngtcp2;https://github.com/status-im/nim-ngtcp2@#9456daa178c655bccd4a3c78ad3b8cce1f0add73
nimcrypto;https://github.com/cheatfate/nimcrypto@#19c41d6be4c00b4a2c8000583bd30cf8ceb5f4b1
quic;https://github.com/status-im/nim-quic.git@#d9a4cbccd509f7a3ee835f75b01dec29d27a0f14
quic;https://github.com/vacp2p/nim-quic@#cae13c2d22ba2730c979486cf89b88927045c3ae
results;https://github.com/arnetheduck/nim-results@#df8113dda4c2d74d460a8fa98252b0b771bf1f27
secp256k1;https://github.com/status-im/nim-secp256k1@#f808ed5e7a7bfc42204ec7830f14b7a42b63c284
serialization;https://github.com/status-im/nim-serialization@#548d0adc9797a10b2db7f788b804330306293088

View File

@@ -10,7 +10,7 @@ skipDirs = @["tests", "examples", "Nim", "tools", "scripts", "docs"]
requires "nim >= 2.0.0",
"nimcrypto >= 0.6.0 & < 0.7.0", "dnsclient >= 0.3.0 & < 0.4.0", "bearssl >= 0.2.5",
"chronicles >= 0.11.0 & < 0.12.0", "chronos >= 4.0.4", "metrics", "secp256k1",
"stew >= 0.4.0", "websock >= 0.2.0", "unittest2", "results", "quic >= 0.2.10",
"stew >= 0.4.0", "websock >= 0.2.0", "unittest2", "results", "quic >= 0.2.15",
"https://github.com/vacp2p/nim-jwt.git#18f8378de52b241f321c1f9ea905456e89b95c6f"
let nimc = getEnv("NIMC", "nim") # Which nim compiler to use

View File

@@ -85,6 +85,7 @@ when defined(libp2p_autotls_support):
../crypto/rsa,
../utils/heartbeat,
../transports/transport,
../utils/ipaddr,
../transports/tcptransport,
../nameresolving/dnsresolver
@@ -150,7 +151,10 @@ when defined(libp2p_autotls_support):
if self.config.ipAddress.isNone():
try:
self.config.ipAddress = Opt.some(getPublicIPAddress())
except AutoTLSError as exc:
except ValueError as exc:
error "Failed to get public IP address", err = exc.msg
return false
except OSError as exc:
error "Failed to get public IP address", err = exc.msg
return false
self.managerFut = self.run(switch)

View File

@@ -22,7 +22,7 @@ const
type AutoTLSError* = object of LPError
when defined(libp2p_autotls_support):
import net, strutils
import strutils
from times import DateTime, toTime, toUnix
import stew/base36
import
@@ -33,36 +33,6 @@ when defined(libp2p_autotls_support):
../nameresolving/nameresolver,
./acme/client
proc checkedGetPrimaryIPAddr*(): IpAddress {.raises: [AutoTLSError].} =
# This is so that we don't need to catch Exceptions directly
# since we support 1.6.16 and getPrimaryIPAddr before nim 2 didn't have explicit .raises. pragmas
try:
return getPrimaryIPAddr()
except Exception as exc:
raise newException(AutoTLSError, "Error while getting primary IP address", exc)
proc isIPv4*(ip: IpAddress): bool =
ip.family == IpAddressFamily.IPv4
proc isPublic*(ip: IpAddress): bool {.raises: [AutoTLSError].} =
let ip = $ip
try:
not (
ip.startsWith("10.") or
(ip.startsWith("172.") and parseInt(ip.split(".")[1]) in 16 .. 31) or
ip.startsWith("192.168.") or ip.startsWith("127.") or ip.startsWith("169.254.")
)
except ValueError as exc:
raise newException(AutoTLSError, "Failed to parse IP address", exc)
proc getPublicIPAddress*(): IpAddress {.raises: [AutoTLSError].} =
let ip = checkedGetPrimaryIPAddr()
if not ip.isIPv4():
raise newException(AutoTLSError, "Host does not have an IPv4 address")
if not ip.isPublic():
raise newException(AutoTLSError, "Host does not have a public IPv4 address")
return ip
proc asMoment*(dt: DateTime): Moment =
let unixTime: int64 = dt.toTime.toUnix
return Moment.init(unixTime, Second)

View File

@@ -26,7 +26,8 @@ import
transports/[transport, tcptransport, wstransport, memorytransport],
muxers/[muxer, mplex/mplex, yamux/yamux],
protocols/[identify, secure/secure, secure/noise, rendezvous],
protocols/connectivity/[autonat/server, relay/relay, relay/client, relay/rtransport],
protocols/connectivity/
[autonat/server, autonatv2/server, relay/relay, relay/client, relay/rtransport],
connmanager,
upgrademngrs/muxedupgrade,
observedaddrmanager,
@@ -74,6 +75,7 @@ type
nameResolver: NameResolver
peerStoreCapacity: Opt[int]
autonat: bool
autonatV2: bool
autotls: AutotlsService
circuitRelay: Relay
rdv: RendezVous
@@ -280,6 +282,10 @@ proc withAutonat*(b: SwitchBuilder): SwitchBuilder =
b.autonat = true
b
proc withAutonatV2*(b: SwitchBuilder): SwitchBuilder =
b.autonatV2 = true
b
when defined(libp2p_autotls_support):
proc withAutotls*(
b: SwitchBuilder, config: AutotlsConfig = AutotlsConfig.new()
@@ -379,7 +385,10 @@ proc build*(b: SwitchBuilder): Switch {.raises: [LPError], public.} =
switch.mount(identify)
if b.autonat:
if b.autonatV2:
let autonatV2 = AutonatV2.new(switch)
switch.mount(autonatV2)
elif b.autonat:
let autonat = Autonat.new(switch)
switch.mount(autonat)
@@ -395,7 +404,7 @@ proc build*(b: SwitchBuilder): Switch {.raises: [LPError], public.} =
return switch
proc newStandardSwitch*(
proc newStandardSwitchBuilder*(
privKey = none(PrivateKey),
addrs: MultiAddress | seq[MultiAddress] =
MultiAddress.init("/ip4/127.0.0.1/tcp/0").expect("valid address"),
@@ -411,7 +420,7 @@ proc newStandardSwitch*(
nameResolver: NameResolver = nil,
sendSignedPeerRecord = false,
peerStoreCapacity = 1000,
): Switch {.raises: [LPError], public.} =
): SwitchBuilder {.raises: [LPError], public.} =
## Helper for common switch configurations.
let addrs =
when addrs is MultiAddress:
@@ -436,4 +445,39 @@ proc newStandardSwitch*(
privKey.withValue(pkey):
b = b.withPrivateKey(pkey)
b.build()
b
proc newStandardSwitch*(
privKey = none(PrivateKey),
addrs: MultiAddress | seq[MultiAddress] =
MultiAddress.init("/ip4/127.0.0.1/tcp/0").expect("valid address"),
secureManagers: openArray[SecureProtocol] = [SecureProtocol.Noise],
transportFlags: set[ServerFlags] = {},
rng = newRng(),
inTimeout: Duration = 5.minutes,
outTimeout: Duration = 5.minutes,
maxConnections = MaxConnections,
maxIn = -1,
maxOut = -1,
maxConnsPerPeer = MaxConnectionsPerPeer,
nameResolver: NameResolver = nil,
sendSignedPeerRecord = false,
peerStoreCapacity = 1000,
): Switch {.raises: [LPError], public.} =
newStandardSwitchBuilder(
privKey = privKey,
addrs = addrs,
secureManagers = secureManagers,
transportFlags = transportFlags,
rng = rng,
inTimeout = inTimeout,
outTimeout = outTimeout,
maxConnections = maxConnections,
maxIn = maxIn,
maxOut = maxOut,
maxConnsPerPeer = maxConnsPerPeer,
nameResolver = nameResolver,
sendSignedPeerRecord = sendSignedPeerRecord,
peerStoreCapacity = peerStoreCapacity,
)
.build()

View File

@@ -11,7 +11,7 @@
import chronos
import results
import peerid, stream/connection, transports/transport
import peerid, stream/connection, transports/transport, muxers/muxer
export results
@@ -65,6 +65,23 @@ method dial*(
method addTransport*(self: Dial, transport: Transport) {.base.} =
doAssert(false, "[Dial.addTransport] abstract method not implemented!")
method dialAndUpgrade*(
self: Dial, peerId: Opt[PeerId], address: MultiAddress, dir = Direction.Out
): Future[Muxer] {.base, async: (raises: [CancelledError]).} =
doAssert(false, "[Dial.dialAndUpgrade] abstract method not implemented!")
method dialAndUpgrade*(
self: Dial, peerId: Opt[PeerId], addrs: seq[MultiAddress], dir = Direction.Out
): Future[Muxer] {.
base, async: (raises: [CancelledError, MaError, TransportAddressError, LPError])
.} =
doAssert(false, "[Dial.dialAndUpgrade] abstract method not implemented!")
method negotiateStream*(
self: Dial, conn: Connection, protos: seq[string]
): Future[Connection] {.base, async: (raises: [CatchableError]).} =
doAssert(false, "[Dial.negotiateStream] abstract method not implemented!")
method tryDial*(
self: Dial, peerId: PeerId, addrs: seq[MultiAddress]
): Future[Opt[MultiAddress]] {.

View File

@@ -43,7 +43,7 @@ type Dialer* = ref object of Dial
peerStore: PeerStore
nameResolver: NameResolver
proc dialAndUpgrade(
proc dialAndUpgrade*(
self: Dialer,
peerId: Opt[PeerId],
hostname: string,
@@ -139,7 +139,7 @@ proc expandDnsAddr(
else:
result.add((resolvedAddress, peerId))
proc dialAndUpgrade(
proc dialAndUpgrade*(
self: Dialer, peerId: Opt[PeerId], addrs: seq[MultiAddress], dir = Direction.Out
): Future[Muxer] {.
async: (raises: [CancelledError, MaError, TransportAddressError, LPError])
@@ -284,7 +284,7 @@ method connect*(
return
(await self.internalConnect(Opt.none(PeerId), @[address], false)).connection.peerId
proc negotiateStream(
proc negotiateStream*(
self: Dialer, conn: Connection, protos: seq[string]
): Future[Connection] {.async: (raises: [CatchableError]).} =
trace "Negotiating stream", conn, protos

View File

@@ -27,7 +27,7 @@ macro checkFutures*[F](futs: seq[F], exclude: untyped = []): untyped =
quote:
for res in `futs`:
if res.failed:
let exc = res.readError()
let exc = res.error
# We still don't abort but warn
debug "A future has failed, enable trace logging for details",
error = exc.name
@@ -37,7 +37,7 @@ macro checkFutures*[F](futs: seq[F], exclude: untyped = []): untyped =
for res in `futs`:
block check:
if res.failed:
let exc = res.readError()
let exc = res.error
for i in 0 ..< `nexclude`:
if exc of `exclude`[i]:
trace "A future has failed", error = exc.name, description = exc.msg

View File

@@ -843,6 +843,14 @@ proc init*(
res.data.finish()
ok(res)
proc getPart*(ma: MultiAddress, codec: MultiCodec): MaResult[MultiAddress] =
## Returns the first multiaddress in ``value`` with codec ``codec``
for part in ma:
let part = ?part
if codec == ?part.protoCode:
return ok(part)
err("no such codec in multiaddress")
proc getProtocol(name: string): MAProtocol {.inline.} =
let mc = MultiCodec.codec(name)
if mc != InvalidMultiCodec:
@@ -1119,3 +1127,32 @@ proc getRepeatedField*(
err(ProtoError.IncorrectBlob)
else:
ok(true)
proc areAddrsConsistent*(a, b: MultiAddress): bool =
## Checks if two multiaddresses have the same protocol stack.
let protosA = a.protocols().get()
let protosB = b.protocols().get()
if protosA.len != protosB.len:
return false
for idx in 0 ..< protosA.len:
let protoA = protosA[idx]
let protoB = protosB[idx]
if protoA != protoB:
if idx == 0:
# allow DNS ↔ IP at the first component
if protoB == multiCodec("dns") or protoB == multiCodec("dnsaddr"):
if not (protoA == multiCodec("ip4") or protoA == multiCodec("ip6")):
return false
elif protoB == multiCodec("dns4"):
if protoA != multiCodec("ip4"):
return false
elif protoB == multiCodec("dns6"):
if protoA != multiCodec("ip6"):
return false
else:
return false
else:
return false
true

View File

@@ -150,6 +150,10 @@ method close*(s: LPChannel) {.async: (raises: []).} =
trace "Closed channel", s, len = s.len
method closeWrite*(s: LPChannel) {.async: (raises: []).} =
## For mplex, closeWrite is the same as close - it implements half-close
await s.close()
method initStream*(s: LPChannel) =
if s.objName.len == 0:
s.objName = LPChannelTrackerName

View File

@@ -95,6 +95,7 @@ proc newStreamInternal*(
result.peerId = m.connection.peerId
result.observedAddr = m.connection.observedAddr
result.localAddr = m.connection.localAddr
result.transportDir = m.connection.transportDir
when defined(libp2p_agents_metrics):
result.shortAgent = m.connection.shortAgent

View File

@@ -54,6 +54,10 @@ method newStream*(
.} =
raiseAssert("[Muxer.newStream] abstract method not implemented!")
when defined(libp2p_agents_metrics):
method setShortAgent*(m: Muxer, shortAgent: string) {.base, gcsafe.} =
m.connection.shortAgent = shortAgent
method close*(m: Muxer) {.base, async: (raises: []).} =
if m.connection != nil:
await m.connection.close()

View File

@@ -217,7 +217,11 @@ method closeImpl*(channel: YamuxChannel) {.async: (raises: []).} =
discard
await channel.actuallyClose()
proc clearQueues(channel: YamuxChannel, error: ref CatchableError = nil) =
method closeWrite*(channel: YamuxChannel) {.async: (raises: []).} =
## For yamux, closeWrite is the same as close - it implements half-close
await channel.close()
proc clearQueues(channel: YamuxChannel, error: ref LPStreamEOFError = nil) =
for toSend in channel.sendQueue:
if error.isNil():
toSend.fut.complete()
@@ -511,6 +515,7 @@ proc createStream(
stream.initStream()
stream.peerId = m.connection.peerId
stream.observedAddr = m.connection.observedAddr
stream.localAddr = m.connection.localAddr
stream.transportDir = m.connection.transportDir
when defined(libp2p_agents_metrics):
stream.shortAgent = m.connection.shortAgent
@@ -529,13 +534,13 @@ method close*(m: Yamux) {.async: (raises: []).} =
trace "Closing yamux"
let channels = toSeq(m.channels.values())
for channel in channels:
for toSend in channel.sendQueue:
toSend.fut.fail(newLPStreamEOFError())
channel.sendQueue = @[]
channel.clearQueues(newLPStreamEOFError())
channel.recvWindow = 0
channel.sendWindow = 0
channel.closedLocally = true
channel.isReset = true
channel.opened = false
channel.isClosed = true
await channel.remoteClosed()
channel.receivedData.fire()
try:
@@ -607,8 +612,10 @@ method handle*(m: Yamux) {.async: (raises: []).} =
if header.length > 0:
var buffer = newSeqUninit[byte](header.length)
await m.connection.readExactly(addr buffer[0], int(header.length))
do:
raise newException(YamuxError, "Unknown stream ID: " & $header.streamId)
# If we do not have a stream, likely we sent a RST and/or closed the stream
trace "unknown stream id", id = header.streamId
continue
let channel =

View File

@@ -214,7 +214,7 @@ proc identify*(
info.agentVersion.get("").split("/")[0].safeToLowerAscii().get("")
if KnownLibP2PAgentsSeq.contains(shortAgent):
knownAgent = shortAgent
muxer.connection.setShortAgent(knownAgent)
muxer.setShortAgent(knownAgent)
peerStore.updatePeerInfo(info, stream.observedAddr)
finally:

View File

@@ -12,7 +12,7 @@
import results
import chronos, chronicles
import ../../../switch, ../../../multiaddress, ../../../peerid
import core
import types
logScope:
topics = "libp2p autonat"

View File

@@ -20,9 +20,9 @@ import
../../../peerid,
../../../utils/[semaphore, future],
../../../errors
import core
import types
export core
export types
logScope:
topics = "libp2p autonat"

View File

@@ -14,11 +14,11 @@ import chronos, metrics
import ../../../switch
import ../../../wire
import client
from core import NetworkReachability, AutonatUnreachableError
from types import NetworkReachability, AutonatUnreachableError
import ../../../utils/heartbeat
import ../../../crypto/crypto
export core.NetworkReachability
export NetworkReachability
logScope:
topics = "libp2p autonatservice"

View File

@@ -0,0 +1,276 @@
# Nim-LibP2P
# Copyright (c) 2025 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
import results
import chronos, chronicles
import
../../../../libp2p/[
switch,
muxers/muxer,
dialer,
multiaddress,
transports/transport,
multicodec,
peerid,
protobuf/minprotobuf,
utils/ipaddr,
],
../../protocol,
./types
logScope:
topics = "libp2p autonat v2 server"
const
DefaultDialTimeout: Duration = 15.seconds
DefaultAmplificationAttackDialTimeout: Duration = 3.seconds
DefaultDialDataSize: uint64 = 50 * 1024 # 50 KiB > 50 KB
AutonatV2MsgLpSize: int = 1024
# readLp needs to receive more than 4096 bytes (since it's a DialDataResponse) + overhead
AutonatV2DialDataResponseLpSize: int = 5000
type AutonatV2* = ref object of LPProtocol
switch*: Switch
dialTimeout: Duration
dialDataSize: uint64
amplificationAttackTimeout: Duration
proc sendDialResponse(
conn: Connection,
status: ResponseStatus,
addrIdx: Opt[AddrIdx] = Opt.none(AddrIdx),
dialStatus: Opt[DialStatus] = Opt.none(DialStatus),
) {.async: (raises: [CancelledError, LPStreamError]).} =
await conn.writeLp(
AutonatV2Msg(
msgType: MsgType.DialResponse,
dialResp: DialResponse(status: status, addrIdx: addrIdx, dialStatus: dialStatus),
).encode().buffer
)
proc findObservedIPAddr*(
conn: Connection, req: DialRequest
): Future[Opt[MultiAddress]] {.async: (raises: [CancelledError, LPStreamError]).} =
let observedAddr = conn.observedAddr.valueOr:
await conn.sendDialResponse(ResponseStatus.EInternalError)
return Opt.none(MultiAddress)
let isRelayed = observedAddr.contains(multiCodec("p2p-circuit")).valueOr:
error "Invalid observed address"
await conn.sendDialResponse(ResponseStatus.EDialRefused)
return Opt.none(MultiAddress)
if isRelayed:
error "Invalid observed address: relayed address"
await conn.sendDialResponse(ResponseStatus.EDialRefused)
return Opt.none(MultiAddress)
let hostIp = observedAddr[0].valueOr:
error "Invalid observed address"
await conn.sendDialResponse(ResponseStatus.EInternalError)
return Opt.none(MultiAddress)
return Opt.some(hostIp)
proc dialBack(
conn: Connection, nonce: Nonce
): Future[DialStatus] {.
async: (raises: [CancelledError, DialFailedError, LPStreamError])
.} =
try:
# send dial back
await conn.writeLp(DialBack(nonce: nonce).encode().buffer)
# receive DialBackResponse
let dialBackResp = DialBackResponse.decode(
initProtoBuffer(await conn.readLp(AutonatV2MsgLpSize))
).valueOr:
error "DialBack failed, could not decode DialBackResponse"
return DialStatus.EDialBackError
except LPStreamRemoteClosedError as exc:
# failed because of nonce error (remote reset the stream): EDialBackError
error "DialBack failed, remote closed the connection", description = exc.msg
return DialStatus.EDialBackError
# TODO: failed because of client or server resources: EDialError
trace "DialBack successful"
return DialStatus.Ok
proc handleDialDataResponses(
self: AutonatV2, conn: Connection
) {.async: (raises: [CancelledError, AutonatV2Error, LPStreamError]).} =
var dataReceived: uint64 = 0
while dataReceived < self.dialDataSize:
let msg = AutonatV2Msg.decode(
initProtoBuffer(await conn.readLp(AutonatV2DialDataResponseLpSize))
).valueOr:
raise newException(AutonatV2Error, "Received malformed message")
debug "Received message", msgType = $msg.msgType
if msg.msgType != MsgType.DialDataResponse:
raise
newException(AutonatV2Error, "Expecting DialDataResponse, got " & $msg.msgType)
let resp = msg.dialDataResp
dataReceived += resp.data.len.uint64
debug "received data",
dataReceived = resp.data.len.uint64, totalDataReceived = dataReceived
proc amplificationAttackPrevention(
self: AutonatV2, conn: Connection, addrIdx: AddrIdx
): Future[bool] {.async: (raises: [CancelledError, LPStreamError]).} =
# send DialDataRequest
await conn.writeLp(
AutonatV2Msg(
msgType: MsgType.DialDataRequest,
dialDataReq: DialDataRequest(addrIdx: addrIdx, numBytes: self.dialDataSize),
).encode().buffer
)
# recieve DialDataResponses until we're satisfied
try:
if not await self.handleDialDataResponses(conn).withTimeout(self.dialTimeout):
error "Amplification attack prevention timeout",
timeout = self.amplificationAttackTimeout, peer = conn.peerId
return false
except AutonatV2Error as exc:
error "Amplification attack prevention failed", description = exc.msg
return false
return true
proc canDial(self: AutonatV2, addrs: MultiAddress): bool =
let (ipv4Support, ipv6Support) = self.switch.peerInfo.listenAddrs.ipSupport()
addrs[0].withValue(addrIp):
if IP4.match(addrIp) and not ipv4Support:
return false
if IP6.match(addrIp) and not ipv6Support:
return false
try:
if isPrivate($addrIp):
return false
except ValueError:
warn "Unable to parse IP address, skipping", addrs = $addrIp
return false
for t in self.switch.transports:
if t.handles(addrs):
return true
return false
proc forceNewConnection(
self: AutonatV2, pid: PeerId, addrs: seq[MultiAddress]
): Future[Opt[Connection]] {.async: (raises: [CancelledError]).} =
## Bypasses connManager to force a new connection to ``pid``
## instead of reusing a preexistent one
try:
let mux = await self.switch.dialer.dialAndUpgrade(Opt.some(pid), addrs)
if mux.isNil():
return Opt.none(Connection)
return Opt.some(
await self.switch.negotiateStream(
await mux.newStream(), @[$AutonatV2Codec.DialBack]
)
)
except CancelledError as exc:
raise exc
except CatchableError:
return Opt.none(Connection)
proc chooseDialAddr(
self: AutonatV2, pid: PeerId, addrs: seq[MultiAddress]
): Future[Opt[(Connection, AddrIdx)]] {.async: (raises: [CancelledError]).} =
for i, ma in addrs:
if self.canDial(ma):
debug "Trying to dial", chosenAddrs = ma, addrIdx = i
let conn = (await self.forceNewConnection(pid, @[ma])).valueOr:
return Opt.none((Connection, AddrIdx))
return Opt.some((conn, i.AddrIdx))
return Opt.none((Connection, AddrIdx))
proc handleDialRequest(
self: AutonatV2, conn: Connection, req: DialRequest
) {.async: (raises: [CancelledError, DialFailedError, LPStreamError]).} =
let observedIPAddr = (await conn.findObservedIPAddr(req)).valueOr:
error "Could not find observed IP address"
return
let (dialBackConn, addrIdx) = (await self.chooseDialAddr(conn.peerId, req.addrs)).valueOr:
error "No dialable addresses found"
await conn.sendDialResponse(ResponseStatus.EDialRefused)
return
defer:
await dialBackConn.close()
# if observed address for peer is not in address list to try
# then we perform Amplification Attack Prevention
if not ipAddrMatches(observedIPAddr, req.addrs):
debug "Starting amplification attack prevention",
observedIPAddr = observedIPAddr, testAddr = req.addrs[addrIdx]
# send DialDataRequest and wait until dataReceived is enough
if not await self.amplificationAttackPrevention(conn, addrIdx):
return
debug "Sending DialBack",
nonce = req.nonce, addrIdx = addrIdx, addr = req.addrs[addrIdx]
let dialStatus = await dialBackConn.dialBack(req.nonce)
await conn.sendDialResponse(
ResponseStatus.Ok, addrIdx = Opt.some(addrIdx), dialStatus = Opt.some(dialStatus)
)
proc new*(
T: typedesc[AutonatV2],
switch: Switch,
dialTimeout: Duration = DefaultDialTimeout,
dialDataSize: uint64 = DefaultDialDataSize,
amplificationAttackTimeout: Duration = DefaultAmplificationAttackDialTimeout,
): T =
let autonatV2 = T(
switch: switch,
dialTimeout: dialTimeout,
dialDataSize: dialDataSize,
amplificationAttackTimeout: amplificationAttackTimeout,
)
proc handleStream(
conn: Connection, proto: string
) {.async: (raises: [CancelledError]).} =
defer:
await conn.close()
let msg =
try:
AutonatV2Msg.decode(initProtoBuffer(await conn.readLp(AutonatV2MsgLpSize))).valueOr:
trace "Unable to decode AutonatV2Msg"
return
except LPStreamError as exc:
debug "Could not receive AutonatV2Msg", description = exc.msg
return
debug "Received message", msgType = $msg.msgType
if msg.msgType != MsgType.DialRequest:
debug "Expecting DialRequest", receivedMsgType = msg.msgType
return
try:
await autonatV2.handleDialRequest(conn, msg.dialReq)
except CancelledError as exc:
raise exc
except LPStreamRemoteClosedError as exc:
debug "Connection closed by peer", description = exc.msg, peer = conn.peerId
except LPStreamError as exc:
debug "Stream Error", description = exc.msg
except DialFailedError as exc:
debug "Could not dial peer", description = exc.msg, peer = conn.peerId
autonatV2.handler = handleStream
autonatV2.codec = $AutonatV2Codec.DialRequest
autonatV2

View File

@@ -0,0 +1,254 @@
# Nim-LibP2P
# Copyright (c) 2025 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
import results, chronos, chronicles
import
../../../multiaddress, ../../../peerid, ../../../protobuf/minprotobuf, ../../../switch
from ../autonat/types import NetworkReachability
type
AutonatV2Codec* {.pure.} = enum
DialRequest = "/libp2p/autonat/2/dial-request"
DialBack = "/libp2p/autonat/2/dial-back"
AutonatV2Response* = object
reachability*: NetworkReachability
dialResp*: DialResponse
addrs*: Opt[MultiAddress]
AutonatV2Error* = object of LPError
Nonce* = uint64
AddrIdx* = uint32
NumBytes* = uint64
MsgType* {.pure.} = enum
# DialBack and DialBackResponse are not defined as AutonatV2Msg as per the spec
# likely because they are expected in response to some other message
DialRequest
DialResponse
DialDataRequest
DialDataResponse
ResponseStatus* {.pure.} = enum
EInternalError = 0
ERequestRejected = 100
EDialRefused = 101
Ok = 200
DialBackStatus* {.pure.} = enum
Ok = 0
DialStatus* {.pure.} = enum
Unused = 0
EDialError = 100
EDialBackError = 101
Ok = 200
DialRequest* = object
addrs*: seq[MultiAddress]
nonce*: Nonce
DialResponse* = object
status*: ResponseStatus
addrIdx*: Opt[AddrIdx]
dialStatus*: Opt[DialStatus]
DialBack* = object
nonce*: Nonce
DialBackResponse* = object
status*: DialBackStatus
DialDataRequest* = object
addrIdx*: AddrIdx
numBytes*: NumBytes
DialDataResponse* = object
data*: seq[byte]
AutonatV2Msg* = object
case msgType*: MsgType
of MsgType.DialRequest:
dialReq*: DialRequest
of MsgType.DialResponse:
dialResp*: DialResponse
of MsgType.DialDataRequest:
dialDataReq*: DialDataRequest
of MsgType.DialDataResponse:
dialDataResp*: DialDataResponse
# DialRequest
proc encode*(dialReq: DialRequest): ProtoBuffer =
var encoded = initProtoBuffer()
for ma in dialReq.addrs:
encoded.write(1, ma.data.buffer)
encoded.write(2, dialReq.nonce)
encoded.finish()
encoded
proc decode*(T: typedesc[DialRequest], pb: ProtoBuffer): Opt[T] =
var
addrs: seq[MultiAddress]
nonce: Nonce
if not ?pb.getRepeatedField(1, addrs).toOpt():
return Opt.none(T)
if not ?pb.getField(2, nonce).toOpt():
return Opt.none(T)
Opt.some(T(addrs: addrs, nonce: nonce))
# DialResponse
proc encode*(dialResp: DialResponse): ProtoBuffer =
var encoded = initProtoBuffer()
encoded.write(1, dialResp.status.uint)
# minprotobuf casts uses float64 for fixed64 fields
dialResp.addrIdx.withValue(addrIdx):
encoded.write(2, addrIdx)
dialResp.dialStatus.withValue(dialStatus):
encoded.write(3, dialStatus.uint)
encoded.finish()
encoded
proc decode*(T: typedesc[DialResponse], pb: ProtoBuffer): Opt[T] =
var
status: uint
addrIdx: AddrIdx
dialStatus: uint
if not ?pb.getField(1, status).toOpt():
return Opt.none(T)
var optAddrIdx = Opt.none(AddrIdx)
if ?pb.getField(2, addrIdx).toOpt():
optAddrIdx = Opt.some(addrIdx)
var optDialStatus = Opt.none(DialStatus)
if ?pb.getField(3, dialStatus).toOpt():
optDialStatus = Opt.some(cast[DialStatus](dialStatus))
Opt.some(
T(
status: cast[ResponseStatus](status),
addrIdx: optAddrIdx,
dialStatus: optDialStatus,
)
)
# DialBack
proc encode*(dialBack: DialBack): ProtoBuffer =
var encoded = initProtoBuffer()
encoded.write(1, dialBack.nonce)
encoded.finish()
encoded
proc decode*(T: typedesc[DialBack], pb: ProtoBuffer): Opt[T] =
var nonce: Nonce
if not ?pb.getField(1, nonce).toOpt():
return Opt.none(T)
Opt.some(T(nonce: nonce))
# DialBackResponse
proc encode*(dialBackResp: DialBackResponse): ProtoBuffer =
var encoded = initProtoBuffer()
encoded.write(1, dialBackResp.status.uint)
encoded.finish()
encoded
proc decode*(T: typedesc[DialBackResponse], pb: ProtoBuffer): Opt[T] =
var status: uint
if not ?pb.getField(1, status).toOpt():
return Opt.none(T)
Opt.some(T(status: cast[DialBackStatus](status)))
# DialDataRequest
proc encode*(dialDataReq: DialDataRequest): ProtoBuffer =
var encoded = initProtoBuffer()
encoded.write(1, dialDataReq.addrIdx)
encoded.write(2, dialDataReq.numBytes)
encoded.finish()
encoded
proc decode*(T: typedesc[DialDataRequest], pb: ProtoBuffer): Opt[T] =
var
addrIdx: AddrIdx
numBytes: NumBytes
if not ?pb.getField(1, addrIdx).toOpt():
return Opt.none(T)
if not ?pb.getField(2, numBytes).toOpt():
return Opt.none(T)
Opt.some(T(addrIdx: addrIdx, numBytes: numBytes))
# DialDataResponse
proc encode*(dialDataResp: DialDataResponse): ProtoBuffer =
var encoded = initProtoBuffer()
encoded.write(1, dialDataResp.data)
encoded.finish()
encoded
proc decode*(T: typedesc[DialDataResponse], pb: ProtoBuffer): Opt[T] =
var data: seq[byte]
if not ?pb.getField(1, data).toOpt():
return Opt.none(T)
Opt.some(T(data: data))
proc protoField(msgType: MsgType): int =
case msgType
of MsgType.DialRequest: 1.int
of MsgType.DialResponse: 2.int
of MsgType.DialDataRequest: 3.int
of MsgType.DialDataResponse: 4.int
# AutonatV2Msg
proc encode*(msg: AutonatV2Msg): ProtoBuffer =
var encoded = initProtoBuffer()
case msg.msgType
of MsgType.DialRequest:
encoded.write(MsgType.DialRequest.protoField, msg.dialReq.encode())
of MsgType.DialResponse:
encoded.write(MsgType.DialResponse.protoField, msg.dialResp.encode())
of MsgType.DialDataRequest:
encoded.write(MsgType.DialDataRequest.protoField, msg.dialDataReq.encode())
of MsgType.DialDataResponse:
encoded.write(MsgType.DialDataResponse.protoField, msg.dialDataResp.encode())
encoded.finish()
encoded
proc decode*(T: typedesc[AutonatV2Msg], pb: ProtoBuffer): Opt[T] =
var
msgTypeOrd: uint32
msg: ProtoBuffer
if ?pb.getField(MsgType.DialRequest.protoField, msg).toOpt():
let dialReq = DialRequest.decode(msg).valueOr:
return Opt.none(AutonatV2Msg)
Opt.some(AutonatV2Msg(msgType: MsgType.DialRequest, dialReq: dialReq))
elif ?pb.getField(MsgType.DialResponse.protoField, msg).toOpt():
let dialResp = DialResponse.decode(msg).valueOr:
return Opt.none(AutonatV2Msg)
Opt.some(AutonatV2Msg(msgType: MsgType.DialResponse, dialResp: dialResp))
elif ?pb.getField(MsgType.DialDataRequest.protoField, msg).toOpt():
let dialDataReq = DialDataRequest.decode(msg).valueOr:
return Opt.none(AutonatV2Msg)
Opt.some(AutonatV2Msg(msgType: MsgType.DialDataRequest, dialDataReq: dialDataReq))
elif ?pb.getField(MsgType.DialDataResponse.protoField, msg).toOpt():
let dialDataResp = DialDataResponse.decode(msg).valueOr:
return Opt.none(AutonatV2Msg)
Opt.some(
AutonatV2Msg(msgType: MsgType.DialDataResponse, dialDataResp: dialDataResp)
)
else:
Opt.none(AutonatV2Msg)
# Custom `==` is needed to compare since AutonatV2Msg is a case object
proc `==`*(a, b: AutonatV2Msg): bool =
a.msgType == b.msgType and a.encode() == b.encode()

View File

@@ -0,0 +1,47 @@
{.push raises: [].}
import results
import chronos
import
../../protocol,
../../../switch,
../../../multiaddress,
../../../multicodec,
../../../peerid,
../../../protobuf/minprotobuf,
../autonat/service,
./types
proc asNetworkReachability*(self: DialResponse): NetworkReachability =
if self.status == EInternalError:
return Unknown
if self.status == ERequestRejected:
return Unknown
if self.status == EDialRefused:
return Unknown
# if got here it means a dial was attempted
let dialStatus = self.dialStatus.valueOr:
return Unknown
if dialStatus == Unused:
return Unknown
if dialStatus == EDialError:
return NotReachable
if dialStatus == EDialBackError:
return NotReachable
return Reachable
proc asAutonatV2Response*(
self: DialResponse, testAddrs: seq[MultiAddress]
): AutonatV2Response =
let addrIdx = self.addrIdx.valueOr:
return AutonatV2Response(
reachability: self.asNetworkReachability(),
dialResp: self,
addrs: Opt.none(MultiAddress),
)
AutonatV2Response(
reachability: self.asNetworkReachability(),
dialResp: self,
addrs: Opt.some(testAddrs[addrIdx]),
)

View File

@@ -1,6 +1,7 @@
import chronos
import chronicles
import sequtils
import sets
import ../../peerid
import ./consts
import ./xordistance
@@ -9,22 +10,95 @@ import ./lookupstate
import ./requests
import ./keys
import ../protocol
import ../../switch
import ./protobuf
import ../../switch
import ../../multihash
import ../../utils/heartbeat
import std/[options, tables]
import std/[times, options, tables]
import results
logScope:
topics = "kad-dht"
type EntryKey* = object
data: seq[byte]
proc init*(T: typedesc[EntryKey], inner: seq[byte]): EntryKey {.gcsafe, raises: [].} =
EntryKey(data: inner)
type EntryValue* = object
data*: seq[byte] # public because needed for tests
proc init*(
T: typedesc[EntryValue], inner: seq[byte]
): EntryValue {.gcsafe, raises: [].} =
EntryValue(data: inner)
type TimeStamp* = object
# Currently a string, because for some reason, that's what is chosen at the protobuf level
# TODO: convert between RFC3339 strings and use of integers (i.e. the _correct_ way)
ts*: string # only public because needed for tests
type EntryRecord* = object
value*: EntryValue # only public because needed for tests
time*: TimeStamp # only public because needed for tests
proc init*(
T: typedesc[EntryRecord], value: EntryValue, time: Option[TimeStamp]
): EntryRecord {.gcsafe, raises: [].} =
EntryRecord(value: value, time: time.get(TimeStamp(ts: $times.now().utc)))
type LocalTable* = object
entries*: Table[EntryKey, EntryRecord] # public because needed for tests
proc init(self: typedesc[LocalTable]): LocalTable {.raises: [].} =
LocalTable()
type EntryCandidate* = object
key*: EntryKey
value*: EntryValue
type ValidatedEntry* = object
key: EntryKey
value: EntryValue
proc init*(
T: typedesc[ValidatedEntry], key: EntryKey, value: EntryValue
): ValidatedEntry {.gcsafe, raises: [].} =
ValidatedEntry(key: key, value: value)
type EntryValidator* = ref object of RootObj
method isValid*(
self: EntryValidator, key: EntryKey, val: EntryValue
): bool {.base, raises: [], gcsafe.} =
doAssert(false, "unimplimented base method")
type EntrySelector* = ref object of RootObj
method select*(
self: EntrySelector, cand: EntryRecord, others: seq[EntryRecord]
): Result[EntryRecord, string] {.base, raises: [], gcsafe.} =
doAssert(false, "EntrySelection base not implemented")
type KadDHT* = ref object of LPProtocol
switch: Switch
rng: ref HmacDrbgContext
rtable*: RoutingTable
maintenanceLoop: Future[void]
dataTable*: LocalTable
entryValidator: EntryValidator
entrySelector: EntrySelector
proc insert*(
self: var LocalTable, value: sink ValidatedEntry, time: TimeStamp
) {.raises: [].} =
debug "local table insertion", key = value.key.data, value = value.value.data
self.entries[value.key] = EntryRecord(value: value.value, time: time)
const MaxMsgSize = 4096
# Forward declaration
proc findNode*(
kad: KadDHT, targetId: Key
): Future[seq[PeerId]] {.async: (raises: [CancelledError]).}
proc sendFindNode(
kad: KadDHT, peerId: PeerId, addrs: seq[MultiAddress], targetId: Key
@@ -36,16 +110,13 @@ proc sendFindNode(
await kad.switch.dial(peerId, KadCodec)
else:
await kad.switch.dial(peerId, addrs, KadCodec)
defer:
await conn.close()
let msg = Message(msgType: MessageType.findNode, key: some(targetId.getBytes()))
await conn.writeLp(msg.encode().buffer)
let reply = Message.decode(await conn.readLp(MaxMsgSize)).tryGet()
if reply.msgType != MessageType.findNode:
raise newException(ValueError, "unexpected message type in reply: " & $reply)
@@ -68,16 +139,75 @@ proc waitRepliesOrTimeouts(
return (receivedReplies, failedPeers)
proc dispatchPutVal(
kad: KadDHT, peer: PeerId, entry: ValidatedEntry
): Future[void] {.async: (raises: [CancelledError, DialFailedError, LPStreamError]).} =
let conn = await kad.switch.dial(peer, KadCodec)
defer:
await conn.close()
let msg = Message(
msgType: MessageType.putValue,
record: some(Record(key: some(entry.key.data), value: some(entry.value.data))),
)
await conn.writeLp(msg.encode().buffer)
let reply = Message.decode(await conn.readLp(MaxMsgSize)).valueOr:
# todo log this more meaningfully
error "putValue reply decode fail", error = error, conn = conn
return
if reply != msg:
error "unexpected change between msg and reply: ",
msg = msg, reply = reply, conn = conn
proc putValue*(
kad: KadDHT, entKey: EntryKey, value: EntryValue, timeout: Option[int]
): Future[Result[void, string]] {.async: (raises: [CancelledError]), gcsafe.} =
if not kad.entryValidator.isValid(entKey, value):
return err("invalid key/value pair")
let others: seq[EntryRecord] =
if entKey in kad.dataTable.entries:
@[kad.dataTable.entries.getOrDefault(entKey)]
else:
@[]
let candAsRec = EntryRecord.init(value, none(TimeStamp))
let confirmedRec = kad.entrySelector.select(candAsRec, others).valueOr:
error "application provided selector error (local)", msg = error
return err(error)
trace "local putval", candidate = candAsRec, others = others, selected = confirmedRec
let validEnt = ValidatedEntry.init(entKey, confirmedRec.value)
let peers = await kad.findNode(entKey.data.toKey())
# We first prime the sends so the data is ready to go
let rpcBatch = peers.mapIt(kad.dispatchPutVal(it, validEnt))
# then we do the `move`, as insert takes the data as `sink`
kad.dataTable.insert(validEnt, confirmedRec.time)
try:
# now that the all the data is where it needs to be in memory, we can dispatch the
# RPCs
await rpcBatch.allFutures().wait(chronos.seconds(timeout.get(5)))
# It's quite normal for the dispatch to timeout, as it would require all calls to get
# their response. Downstream users may desire some sort of functionality in the
# future to get rpc telemetry, but in the meantime, we just move on...
except AsyncTimeoutError:
discard
return results.ok()
# Helper function forward declaration
proc checkConvergence(state: LookupState, me: PeerId): bool {.raises: [], gcsafe.}
proc findNode*(
kad: KadDHT, targetId: Key
): Future[seq[PeerId]] {.async: (raises: [CancelledError]).} =
## Node lookup. Iteratively search for the k closest peers to a target ID.
## Not necessarily will return the target itself
#debug "findNode", target = target
# TODO: should it return a single peer instead? read spec
var initialPeers = kad.rtable.findClosestPeers(targetId, DefaultReplic)
var state = LookupState.init(targetId, initialPeers)
var state = LookupState.init(targetId, initialPeers, kad.rtable.hasher)
var addrTable: Table[PeerId, seq[MultiAddress]] =
initTable[PeerId, seq[MultiAddress]]()
@@ -94,7 +224,7 @@ proc findNode*(
pendingFutures[peer] = kad
.sendFindNode(peer, addrTable.getOrDefault(peer, @[]), targetId)
.wait(5.seconds)
.wait(chronos.seconds(5))
state.activeQueries.inc
@@ -102,14 +232,24 @@ proc findNode*(
for msg in successfulReplies:
for peer in msg.closerPeers:
addrTable[PeerId.init(peer.id).get()] = peer.addrs
let pid = PeerId.init(peer.id)
if not pid.isOk:
error "PeerId init went bad. this is unusual", data = peer.id
continue
addrTable[pid.get()] = peer.addrs
state.updateShortlist(
msg,
proc(p: PeerInfo) =
discard kad.rtable.insert(p.peerId)
kad.switch.peerStore[AddressBook][p.peerId] = p.addrs
# Nodes might return different addresses for a peer, so we append instead of replacing
var existingAddresses =
kad.switch.peerStore[AddressBook][p.peerId].toHashSet()
for a in p.addrs:
existingAddresses.incl(a)
kad.switch.peerStore[AddressBook][p.peerId] = existingAddresses.toSeq()
# TODO: add TTL to peerstore, otherwise we can spam it with junk
,
kad.rtable.hasher,
)
for timedOut in timedOutPeers:
@@ -120,6 +260,25 @@ proc findNode*(
return state.selectClosestK()
proc findPeer*(
kad: KadDHT, peer: PeerId
): Future[Result[PeerInfo, string]] {.async: (raises: [CancelledError]).} =
## Walks the key space until it finds candidate addresses for a peer Id
if kad.switch.peerInfo.peerId == peer:
# Looking for yourself.
return ok(kad.switch.peerInfo)
if kad.switch.isConnected(peer):
# Return known info about already connected peer
return ok(PeerInfo(peerId: peer, addrs: kad.switch.peerStore[AddressBook][peer]))
let foundNodes = await kad.findNode(peer.toKey())
if not foundNodes.contains(peer):
return err("peer not found")
return ok(PeerInfo(peerId: peer, addrs: kad.switch.peerStore[AddressBook][peer]))
proc checkConvergence(state: LookupState, me: PeerId): bool {.raises: [], gcsafe.} =
let ready = state.activeQueries == 0
let noNew = selectAlphaPeers(state).filterIt(me != it).len == 0
@@ -132,28 +291,38 @@ proc bootstrap*(
try:
await kad.switch.connect(b.peerId, b.addrs)
debug "connected to bootstrap peer", peerId = b.peerId
except CatchableError as e:
error "failed to connect to bootstrap peer", peerId = b.peerId, error = e.msg
except DialFailedError as e:
# at some point will want to bubble up a Result[void, SomeErrorEnum]
error "failed to dial to bootstrap peer", peerId = b.peerId, error = e.msg
continue
try:
let msg =
await kad.sendFindNode(b.peerId, b.addrs, kad.rtable.selfId).wait(5.seconds)
for peer in msg.closerPeers:
let p = PeerId.init(peer.id).tryGet()
discard kad.rtable.insert(p)
let msg =
try:
await kad.sendFindNode(b.peerId, b.addrs, kad.rtable.selfId).wait(
chronos.seconds(5)
)
except CatchableError as e:
debug "send find node exception during bootstrap",
target = b.peerId, addrs = b.addrs, err = e.msg
continue
for peer in msg.closerPeers:
let p = PeerId.init(peer.id).valueOr:
debug "invalid peer id received", error = error
continue
discard kad.rtable.insert(p)
try:
kad.switch.peerStore[AddressBook][p] = peer.addrs
except:
error "this is here because an ergonomic means of keying into a table without exceptions is unknown"
# bootstrap node replied succesfully. Adding to routing table
discard kad.rtable.insert(b.peerId)
except CatchableError as e:
error "bootstrap failed for peer", peerId = b.peerId, exc = e.msg
# bootstrap node replied succesfully. Adding to routing table
discard kad.rtable.insert(b.peerId)
try:
# Adding some random node to prepopulate the table
discard await kad.findNode(PeerId.random(kad.rng).tryGet().toKey())
info "bootstrap lookup complete"
except CatchableError as e:
error "bootstrap lookup failed", error = e.msg
let key = PeerId.random(kad.rng).valueOr:
doAssert(false, "this should never happen")
return
discard await kad.findNode(key.toKey())
info "bootstrap lookup complete"
proc refreshBuckets(kad: KadDHT) {.async: (raises: [CancelledError]).} =
for i in 0 ..< kad.rtable.buckets.len:
@@ -162,49 +331,121 @@ proc refreshBuckets(kad: KadDHT) {.async: (raises: [CancelledError]).} =
discard await kad.findNode(randomKey)
proc maintainBuckets(kad: KadDHT) {.async: (raises: [CancelledError]).} =
heartbeat "refresh buckets", 10.minutes:
heartbeat "refresh buckets", chronos.minutes(10):
await kad.refreshBuckets()
proc new*(
T: typedesc[KadDHT], switch: Switch, rng: ref HmacDrbgContext = newRng()
T: typedesc[KadDHT],
switch: Switch,
validator: EntryValidator,
entrySelector: EntrySelector,
rng: ref HmacDrbgContext = newRng(),
): T {.raises: [].} =
var rtable = RoutingTable.init(switch.peerInfo.peerId.toKey())
let kad = T(rng: rng, switch: switch, rtable: rtable)
var rtable = RoutingTable.init(switch.peerInfo.peerId.toKey(), Opt.none(XorDHasher))
let kad = T(
rng: rng,
switch: switch,
rtable: rtable,
dataTable: LocalTable.init(),
entryValidator: validator,
entrySelector: entrySelector,
)
kad.codec = KadCodec
kad.handler = proc(
conn: Connection, proto: string
) {.async: (raises: [CancelledError]).} =
try:
while not conn.atEof:
let
buf = await conn.readLp(MaxMsgSize)
msg = Message.decode(buf).tryGet()
case msg.msgType
of MessageType.findNode:
let targetIdBytes = msg.key.get()
let targetId = PeerId.init(targetIdBytes).tryGet()
let closerPeers = kad.rtable.findClosest(targetId.toKey(), DefaultReplic)
let responsePb = encodeFindNodeReply(closerPeers, switch)
await conn.writeLp(responsePb.buffer)
# Peer is useful. adding to rtable
discard kad.rtable.insert(conn.peerId)
else:
raise newException(LPError, "unhandled kad-dht message type")
except CancelledError as exc:
raise exc
except CatchableError:
discard
# TODO: figure out why this fails:
# error "could not handle request",
# peerId = conn.PeerId, err = getCurrentExceptionMsg()
finally:
defer:
await conn.close()
while not conn.atEof:
let buf =
try:
await conn.readLp(MaxMsgSize)
except LPStreamError as e:
debug "Read error when handling kademlia RPC", conn = conn, err = e.msg
return
let msg = Message.decode(buf).valueOr:
debug "msg decode error handling kademlia RPC", err = error
return
case msg.msgType
of MessageType.findNode:
let targetIdBytes = msg.key.valueOr:
error "findNode message without key data present", msg = msg, conn = conn
return
let targetId = PeerId.init(targetIdBytes).valueOr:
error "findNode message without valid key data", msg = msg, conn = conn
return
let closerPeers = kad.rtable
.findClosest(targetId.toKey(), DefaultReplic)
# exclude the node requester because telling a peer about itself does not reduce the distance,
.filterIt(it != conn.peerId.toKey())
let responsePb = encodeFindNodeReply(closerPeers, switch)
try:
await conn.writeLp(responsePb.buffer)
except LPStreamError as e:
debug "write error when writing kad find-node RPC reply",
conn = conn, err = e.msg
return
# Peer is useful. adding to rtable
discard kad.rtable.insert(conn.peerId)
of MessageType.putValue:
let record = msg.record.valueOr:
error "no record in message buffer", msg = msg, conn = conn
return
let (skey, svalue) =
if record.key.isSome() and record.value.isSome():
(record.key.unsafeGet(), record.value.unsafeGet())
else:
error "no key or no value in rpc buffer", msg = msg, conn = conn
return
let key = EntryKey.init(skey)
let value = EntryValue.init(svalue)
# Value sanitisation done. Start insertion process
if not kad.entryValidator.isValid(key, value):
return
let others =
if kad.dataTable.entries.contains(key):
# need to do this shenans in order to avoid exceptions.
@[kad.dataTable.entries.getOrDefault(key)]
else:
@[]
let candRec = EntryRecord.init(value, none(TimeStamp))
let selectedRec = kad.entrySelector.select(candRec, others).valueOr:
error "application provided selector error", msg = error, conn = conn
return
trace "putval handler selection",
cand = candRec, others = others, selected = selectedRec
# Assume that if selection goes with another value, that it is valid
let validated = ValidatedEntry(key: key, value: selectedRec.value)
kad.dataTable.insert(validated, selectedRec.time)
# consistent with following link, echo message without change
# https://github.com/libp2p/js-libp2p/blob/cf9aab5c841ec08bc023b9f49083c95ad78a7a07/packages/kad-dht/src/rpc/handlers/put-value.ts#L22
try:
await conn.writeLp(buf)
except LPStreamError as e:
debug "write error when writing kad find-node RPC reply",
conn = conn, err = e.msg
return
else:
error "unhandled kad-dht message type", msg = msg
return
return kad
proc setSelector*(kad: KadDHT, selector: EntrySelector) =
doAssert(selector != nil)
kad.entrySelector = selector
proc setValidator*(kad: KadDHT, validator: EntryValidator) =
doAssert(validator != nil)
kad.entryValidator = validator
method start*(
kad: KadDHT
): Future[void] {.async: (raises: [CancelledError], raw: true).} =

View File

@@ -1,11 +1,10 @@
import nimcrypto/sha2
import ../../peerid
import ./consts
import chronicles
import stew/byteutils
type
KeyType* {.pure.} = enum
Unhashed
Raw
PeerId
@@ -13,15 +12,11 @@ type
case kind*: KeyType
of KeyType.PeerId:
peerId*: PeerId
of KeyType.Raw, KeyType.Unhashed:
data*: array[IdLength, byte]
of KeyType.Raw:
data*: seq[byte]
proc toKey*(s: seq[byte]): Key =
doAssert s.len == IdLength
var data: array[IdLength, byte]
for i in 0 ..< IdLength:
data[i] = s[i]
return Key(kind: KeyType.Raw, data: data)
return Key(kind: KeyType.Raw, data: s)
proc toKey*(p: PeerId): Key =
return Key(kind: KeyType.PeerId, peerId: p)
@@ -36,7 +31,7 @@ proc getBytes*(k: Key): seq[byte] =
case k.kind
of KeyType.PeerId:
k.peerId.getBytes()
of KeyType.Raw, KeyType.Unhashed:
of KeyType.Raw:
@(k.data)
template `==`*(a, b: Key): bool =
@@ -46,7 +41,7 @@ proc shortLog*(k: Key): string =
case k.kind
of KeyType.PeerId:
"PeerId:" & $k.peerId
of KeyType.Raw, KeyType.Unhashed:
of KeyType.Raw:
$k.kind & ":" & toHex(k.data)
chronicles.formatIt(Key):

View File

@@ -27,7 +27,10 @@ proc alreadyInShortlist(state: LookupState, peer: Peer): bool =
return state.shortlist.anyIt(it.peerId.getBytes() == peer.id)
proc updateShortlist*(
state: var LookupState, msg: Message, onInsert: proc(p: PeerInfo) {.gcsafe.}
state: var LookupState,
msg: Message,
onInsert: proc(p: PeerInfo) {.gcsafe.},
hasher: Opt[XorDHasher],
) =
for newPeer in msg.closerPeers.filterIt(not alreadyInShortlist(state, it)):
let peerInfo = PeerInfo(peerId: PeerId.init(newPeer.id).get(), addrs: newPeer.addrs)
@@ -36,7 +39,7 @@ proc updateShortlist*(
state.shortlist.add(
LookupNode(
peerId: peerInfo.peerId,
distance: xorDistance(peerInfo.peerId, state.targetId),
distance: xorDistance(peerInfo.peerId, state.targetId, hasher),
queried: false,
pending: false,
failed: false,
@@ -77,7 +80,12 @@ proc selectAlphaPeers*(state: LookupState): seq[PeerId] =
break
return selected
proc init*(T: type LookupState, targetId: Key, initialPeers: seq[PeerId]): T =
proc init*(
T: type LookupState,
targetId: Key,
initialPeers: seq[PeerId],
hasher: Opt[XorDHasher],
): T =
var res = LookupState(
targetId: targetId,
shortlist: @[],
@@ -90,7 +98,7 @@ proc init*(T: type LookupState, targetId: Key, initialPeers: seq[PeerId]): T =
res.shortlist.add(
LookupNode(
peerId: p,
distance: xorDistance(p, targetId),
distance: xorDistance(p, targetId, hasher),
queried: false,
pending: false,
failed: false,

View File

@@ -8,6 +8,7 @@ import ./xordistance
import ../../peerid
import sequtils
import ../../utils/sequninit
import results
logScope:
topics = "kad-dht rtable"
@@ -23,15 +24,16 @@ type
RoutingTable* = ref object
selfId*: Key
buckets*: seq[Bucket]
hasher*: Opt[XorDHasher]
proc `$`*(rt: RoutingTable): string =
"selfId(" & $rt.selfId & ") buckets(" & $rt.buckets & ")"
proc init*(T: typedesc[RoutingTable], selfId: Key): T =
return RoutingTable(selfId: selfId, buckets: @[])
proc init*(T: typedesc[RoutingTable], selfId: Key, hasher: Opt[XorDHasher]): T =
return RoutingTable(selfId: selfId, buckets: @[], hasher: hasher)
proc bucketIndex*(selfId, key: Key): int =
return xorDistance(selfId, key).leadingZeros
proc bucketIndex*(selfId, key: Key, hasher: Opt[XorDHasher]): int =
return xorDistance(selfId, key, hasher).leadingZeros
proc peerIndexInBucket(bucket: var Bucket, nodeId: Key): Opt[int] =
for i, p in bucket.peers:
@@ -43,7 +45,7 @@ proc insert*(rtable: var RoutingTable, nodeId: Key): bool =
if nodeId == rtable.selfId:
return false # No self insertion
let idx = bucketIndex(rtable.selfId, nodeId)
let idx = bucketIndex(rtable.selfId, nodeId, rtable.hasher)
if idx >= maxBuckets:
trace "cannot insert node. max buckets have been reached",
nodeId, bucketIdx = idx, maxBuckets
@@ -80,7 +82,9 @@ proc findClosest*(rtable: RoutingTable, targetId: Key, count: int): seq[Key] =
allNodes.sort(
proc(a, b: Key): int =
cmp(xorDistance(a, targetId), xorDistance(b, targetId))
cmp(
xorDistance(a, targetId, rtable.hasher), xorDistance(b, targetId, rtable.hasher)
)
)
return allNodes[0 ..< min(count, allNodes.len)]

View File

@@ -1,9 +1,27 @@
import ./consts
import stew/arrayOps
import ./keys
import nimcrypto/sha2
import ../../peerid
import results
type XorDistance* = array[IdLength, byte]
type XorDHasher* = proc(input: seq[byte]): array[IdLength, byte] {.
raises: [], nimcall, noSideEffect, gcsafe
.}
proc defaultHasher(
input: seq[byte]
): array[IdLength, byte] {.raises: [], nimcall, noSideEffect, gcsafe.} =
return sha256.digest(input).data
# useful for testing purposes
proc noOpHasher*(
input: seq[byte]
): array[IdLength, byte] {.raises: [], nimcall, noSideEffect, gcsafe.} =
var data: array[IdLength, byte]
discard data.copyFrom(input)
return data
proc countLeadingZeroBits*(b: byte): int =
for i in 0 .. 7:
@@ -31,25 +49,23 @@ proc `<`*(a, b: XorDistance): bool =
proc `<=`*(a, b: XorDistance): bool =
cmp(a, b) <= 0
proc hashFor(k: Key): seq[byte] =
proc hashFor(k: Key, hasher: Opt[XorDHasher]): seq[byte] =
return
@(
case k.kind
of KeyType.PeerId:
sha256.digest(k.peerId.getBytes()).data
hasher.get(defaultHasher)(k.peerId.getBytes())
of KeyType.Raw:
sha256.digest(k.data).data
of KeyType.Unhashed:
k.data
hasher.get(defaultHasher)(k.data)
)
proc xorDistance*(a, b: Key): XorDistance =
let hashA = a.hashFor()
let hashB = b.hashFor()
proc xorDistance*(a, b: Key, hasher: Opt[XorDHasher]): XorDistance =
let hashA = a.hashFor(hasher)
let hashB = b.hashFor(hasher)
var response: XorDistance
for i in 0 ..< hashA.len:
response[i] = hashA[i] xor hashB[i]
return response
proc xorDistance*(a: PeerId, b: Key): XorDistance =
xorDistance(a.toKey(), b)
proc xorDistance*(a: PeerId, b: Key, hasher: Opt[XorDHasher]): XorDistance =
xorDistance(a.toKey(), b, hasher)

View File

@@ -12,8 +12,6 @@
import chronos, chronicles, sequtils
import stew/endians2
import ./core, ../../stream/connection
when defined(libp2p_quic_support):
import ../../transports/quictransport
logScope:
topics = "libp2p perf"
@@ -59,13 +57,8 @@ proc perf*(
statsCopy.uploadBytes += toWrite.uint
p.stats = statsCopy
# Close connection after writing for TCP, but not for QUIC
when defined(libp2p_quic_support):
if not (conn of QuicStream):
await conn.close()
# For QUIC streams, don't close yet - let server manage lifecycle
else:
await conn.close()
# Close write side of the stream (half-close) to signal EOF to server
await conn.closeWrite()
size = sizeToRead
@@ -80,10 +73,8 @@ proc perf*(
statsCopy.downloadBytes += toRead.uint
p.stats = statsCopy
# Close QUIC connections after read phase
when defined(libp2p_quic_support):
if conn of QuicStream:
await conn.close()
# Close the connection after reading
await conn.close()
except CancelledError as e:
raise e
except LPStreamError as e:

View File

@@ -14,8 +14,6 @@
import chronos, chronicles
import stew/endians2
import ./core, ../protocol, ../../stream/connection, ../../utility
when defined(libp2p_quic_support):
import ../../transports/quictransport
export chronicles, connection
@@ -26,50 +24,29 @@ type Perf* = ref object of LPProtocol
proc new*(T: typedesc[Perf]): T {.public.} =
var p = T()
proc handle(conn: Connection, proto: string) {.async: (raises: [CancelledError]).} =
var bytesRead = 0
try:
trace "Received benchmark performance check", conn
var
sizeBuffer: array[8, byte]
size: uint64
await conn.readExactly(addr sizeBuffer[0], 8)
size = uint64.fromBytesBE(sizeBuffer)
var toReadBuffer: array[PerfSize, byte]
try:
# Different handling for QUIC vs TCP streams
when defined(libp2p_quic_support):
if conn of QuicStream:
# QUIC needs timeout-based approach to detect end of upload
while not conn.atEof:
let readFut = conn.readOnce(addr toReadBuffer[0], PerfSize)
let read = readFut.read()
if read == 0:
break
bytesRead += read
else:
# TCP streams handle EOF properly
while true:
let read = await conn.readOnce(addr toReadBuffer[0], PerfSize)
if read == 0:
break
bytesRead += read
else:
# TCP streams handle EOF properly
while true:
let read = await conn.readOnce(addr toReadBuffer[0], PerfSize)
if read == 0:
break
bytesRead += read
except CatchableError:
discard
var uploadSizeBuffer: array[8, byte]
await conn.readExactly(addr uploadSizeBuffer[0], 8)
var uploadSize = uint64.fromBytesBE(uploadSizeBuffer)
var buf: array[PerfSize, byte]
while size > 0:
let toWrite = min(size, PerfSize)
await conn.write(buf[0 ..< toWrite])
size -= toWrite
var readBuffer: array[PerfSize, byte]
while not conn.atEof:
try:
let readBytes = await conn.readOnce(addr readBuffer[0], PerfSize)
if readBytes == 0:
break
except LPStreamEOFError:
break
var writeBuffer: array[PerfSize, byte]
while uploadSize > 0:
let toWrite = min(uploadSize, PerfSize)
await conn.write(writeBuffer[0 ..< toWrite])
uploadSize -= toWrite
except CancelledError as exc:
trace "cancelled perf handler"
raise exc

View File

@@ -725,9 +725,8 @@ method rpcHandler*(
continue
if (msg.signature.len > 0 or g.verifySignature) and not msg.verify():
# always validate if signature is present or required
debug "Dropping message due to failed signature verification",
msgId = shortLog(msgId), peer
debug "Dropping message due to failed signature verification", msg = msg
await g.punishInvalidMessage(peer, msg)
continue

View File

@@ -478,19 +478,16 @@ iterator splitRPCMsg(
## exceeds the `maxSize` when trying to fit into an empty `RPCMsg`, the latter is skipped as too large to send.
## Every constructed `RPCMsg` is then encoded, optionally anonymized, and yielded as a sequence of bytes.
var currentRPCMsg = rpcMsg
currentRPCMsg.messages = newSeq[Message]()
var currentSize = byteSize(currentRPCMsg)
var currentRPCMsg = RPCMsg()
var currentSize = 0
for msg in rpcMsg.messages:
let msgSize = byteSize(msg)
# Check if adding the next message will exceed maxSize
if float(currentSize + msgSize) * 1.1 > float(maxSize):
# Guessing 10% protobuf overhead
if currentRPCMsg.messages.len == 0:
trace "message too big to sent", peer, rpcMsg = shortLog(currentRPCMsg)
if currentSize + msgSize > maxSize:
if msgSize > maxSize:
warn "message too big to sent", peer, rpcMsg = shortLog(msg)
continue # Skip this message
trace "sending msg to peer", peer, rpcMsg = shortLog(currentRPCMsg)
@@ -502,11 +499,9 @@ iterator splitRPCMsg(
currentSize += msgSize
# Check if there is a non-empty currentRPCMsg left to be added
if currentSize > 0 and currentRPCMsg.messages.len > 0:
if currentRPCMsg.messages.len > 0:
trace "sending msg to peer", peer, rpcMsg = shortLog(currentRPCMsg)
yield encodeRpcMsg(currentRPCMsg, anonymize)
else:
trace "message too big to sent", peer, rpcMsg = shortLog(currentRPCMsg)
proc send*(
p: PubSubPeer,
@@ -542,13 +537,8 @@ proc send*(
sendMetrics(msg)
encodeRpcMsg(msg, anonymize)
if encoded.len > p.maxMessageSize and msg.messages.len > 1:
for encodedSplitMsg in splitRPCMsg(p, msg, p.maxMessageSize, anonymize):
asyncSpawn p.sendEncoded(encodedSplitMsg, isHighPriority, useCustomConn)
else:
# If the message size is within limits, send it as is
trace "sending msg to peer", peer = p, rpcMsg = shortLog(msg)
asyncSpawn p.sendEncoded(encoded, isHighPriority, useCustomConn)
trace "sending msg to peer", peer = p, rpcMsg = shortLog(msg)
asyncSpawn p.sendEncoded(encoded, isHighPriority, useCustomConn)
proc canAskIWant*(p: PubSubPeer, msgId: MessageId): bool =
for sentIHave in p.sentIHaves.mitems():

View File

@@ -41,15 +41,36 @@ func defaultMsgIdProvider*(m: Message): Result[MessageId, ValidationResult] =
proc sign*(msg: Message, privateKey: PrivateKey): CryptoResult[seq[byte]] =
ok((?privateKey.sign(PubSubPrefix & encodeMessage(msg, false))).getBytes())
proc extractPublicKey(m: Message): Opt[PublicKey] =
var pubkey: PublicKey
if m.fromPeer.hasPublicKey() and m.fromPeer.extractPublicKey(pubkey):
Opt.some(pubkey)
elif m.key.len > 0 and pubkey.init(m.key):
# check if peerId extracted from m.key is the same as m.fromPeer
let derivedPeerId = PeerId.init(pubkey).valueOr:
warn "could not derive peerId from key field"
return Opt.none(PublicKey)
if derivedPeerId != m.fromPeer:
warn "peerId derived from msg.key is not the same as msg.fromPeer",
derivedPeerId = derivedPeerId, fromPeer = m.fromPeer
return Opt.none(PublicKey)
Opt.some(pubkey)
else:
Opt.none(PublicKey)
proc verify*(m: Message): bool =
if m.signature.len > 0 and m.key.len > 0:
if m.signature.len > 0:
var msg = m
msg.signature = @[]
msg.key = @[]
var remote: Signature
var key: PublicKey
if remote.init(m.signature) and key.init(m.key):
let key = m.extractPublicKey().valueOr:
warn "could not extract public key", msg = m
return false
if remote.init(m.signature):
trace "verifying signature", remoteSignature = remote
result = remote.verify(PubSubPrefix & encodeMessage(msg, false), key)

View File

@@ -35,12 +35,15 @@ declareGauge(libp2p_rendezvous_namespaces, "number of registered namespaces")
const
RendezVousCodec* = "/rendezvous/1.0.0"
# Default minimum TTL per libp2p spec
MinimumDuration* = 2.hours
# Lower validation limit to accommodate Waku requirements
MinimumAcceptedDuration = 1.minutes
MaximumDuration = 72.hours
MaximumMessageLen = 1 shl 22 # 4MB
MinimumNamespaceLen = 1
MaximumNamespaceLen = 255
RegistrationLimitPerPeer = 1000
RegistrationLimitPerPeer* = 1000
DiscoverLimit = 1000'u64
SemaphoreDefaultSize = 5
@@ -69,7 +72,7 @@ type
Register = object
ns: string
signedPeerRecord: seq[byte]
ttl: Opt[uint64] # in seconds
ttl*: Opt[uint64] # in seconds
RegisterResponse = object
status: ResponseStatus
@@ -310,27 +313,27 @@ proc decode(_: typedesc[Message], buf: seq[byte]): Opt[Message] =
type
RendezVousError* = object of DiscoveryError
RegisteredData = object
expiration: Moment
peerId: PeerId
data: Register
expiration*: Moment
peerId*: PeerId
data*: Register
RendezVous* = ref object of LPProtocol
# Registered needs to be an offsetted sequence
# because we need stable index for the cookies.
registered: OffsettedSeq[RegisteredData]
registered*: OffsettedSeq[RegisteredData]
# Namespaces is a table whose key is a salted namespace and
# the value is the index sequence corresponding to this
# namespace in the offsettedqueue.
namespaces: Table[string, seq[int]]
namespaces*: Table[string, seq[int]]
rng: ref HmacDrbgContext
salt: string
defaultDT: Moment
expiredDT: Moment
registerDeletionLoop: Future[void]
#registerEvent: AsyncEvent # TODO: to raise during the heartbeat
# + make the heartbeat sleep duration "smarter"
sema: AsyncSemaphore
peers: seq[PeerId]
cookiesSaved: Table[PeerId, Table[string, seq[byte]]]
cookiesSaved*: Table[PeerId, Table[string, seq[byte]]]
switch: Switch
minDuration: Duration
maxDuration: Duration
@@ -394,9 +397,8 @@ proc sendDiscoverResponseError(
await conn.writeLp(msg.buffer)
proc countRegister(rdv: RendezVous, peerId: PeerId): int =
let n = Moment.now()
for data in rdv.registered:
if data.peerId == peerId and data.expiration > n:
if data.peerId == peerId:
result.inc()
proc save(
@@ -409,7 +411,7 @@ proc save(
if rdv.registered[index].peerId == peerId:
if update == false:
return
rdv.registered[index].expiration = rdv.defaultDT
rdv.registered[index].expiration = rdv.expiredDT
rdv.registered.add(
RegisteredData(
peerId: peerId,
@@ -446,7 +448,7 @@ proc unregister(rdv: RendezVous, conn: Connection, u: Unregister) =
try:
for index in rdv.namespaces[nsSalted]:
if rdv.registered[index].peerId == conn.peerId:
rdv.registered[index].expiration = rdv.defaultDT
rdv.registered[index].expiration = rdv.expiredDT
libp2p_rendezvous_registered.dec()
except KeyError:
return
@@ -468,11 +470,17 @@ proc discover(
await conn.sendDiscoverResponseError(InvalidCookie)
return
else:
Cookie(offset: rdv.registered.low().uint64 - 1)
if d.ns.isSome() and cookie.ns.isSome() and cookie.ns.get() != d.ns.get() or
cookie.offset < rdv.registered.low().uint64 or
cookie.offset > rdv.registered.high().uint64:
cookie = Cookie(offset: rdv.registered.low().uint64 - 1)
# Start from the current lowest index (inclusive)
Cookie(offset: rdv.registered.low().uint64)
if d.ns.isSome() and cookie.ns.isSome() and cookie.ns.get() != d.ns.get():
# Namespace changed: start from the beginning of that namespace
cookie = Cookie(offset: rdv.registered.low().uint64)
elif cookie.offset < rdv.registered.low().uint64:
# Cookie behind available range: reset to current low
cookie.offset = rdv.registered.low().uint64
elif cookie.offset > (rdv.registered.high() + 1).uint64:
# Cookie ahead of available range: reset to one past current high (empty page)
cookie.offset = (rdv.registered.high() + 1).uint64
let namespaces =
if d.ns.isSome():
try:
@@ -485,21 +493,21 @@ proc discover(
if namespaces.len() == 0:
await conn.sendDiscoverResponse(@[], Cookie())
return
var offset = namespaces[^1]
var nextOffset = cookie.offset
let n = Moment.now()
var s = collect(newSeq()):
for index in namespaces:
var reg = rdv.registered[index]
if limit == 0:
offset = index
break
if reg.expiration < n or index.uint64 <= cookie.offset:
if reg.expiration < n or index.uint64 < cookie.offset:
continue
limit.dec()
nextOffset = index.uint64 + 1
reg.data.ttl = Opt.some((reg.expiration - Moment.now()).seconds.uint64)
reg.data
rdv.rng.shuffle(s)
await conn.sendDiscoverResponse(s, Cookie(offset: offset.uint64, ns: d.ns))
await conn.sendDiscoverResponse(s, Cookie(offset: nextOffset, ns: d.ns))
proc advertisePeer(
rdv: RendezVous, peer: PeerId, msg: seq[byte]
@@ -683,7 +691,7 @@ proc unsubscribeLocally*(rdv: RendezVous, ns: string) =
try:
for index in rdv.namespaces[nsSalted]:
if rdv.registered[index].peerId == rdv.switch.peerInfo.peerId:
rdv.registered[index].expiration = rdv.defaultDT
rdv.registered[index].expiration = rdv.expiredDT
except KeyError:
return
@@ -738,10 +746,10 @@ proc new*(
minDuration = MinimumDuration,
maxDuration = MaximumDuration,
): T {.raises: [RendezVousError].} =
if minDuration < 1.minutes:
if minDuration < MinimumAcceptedDuration:
raise newException(RendezVousError, "TTL too short: 1 minute minimum")
if maxDuration > 72.hours:
if maxDuration > MaximumDuration:
raise newException(RendezVousError, "TTL too long: 72 hours maximum")
if minDuration >= maxDuration:
@@ -754,8 +762,8 @@ proc new*(
let rdv = T(
rng: rng,
salt: string.fromBytes(generateBytes(rng[], 8)),
registered: initOffsettedSeq[RegisteredData](1),
defaultDT: Moment.now() - 1.days,
registered: initOffsettedSeq[RegisteredData](),
expiredDT: Moment.now() - 1.days,
#registerEvent: newAsyncEvent(),
sema: newAsyncSemaphore(SemaphoreDefaultSize),
minDuration: minDuration,
@@ -806,7 +814,7 @@ proc new*(
rdv.setup(switch)
return rdv
proc deletesRegister(
proc deletesRegister*(
rdv: RendezVous, interval = 1.minutes
) {.async: (raises: [CancelledError]).} =
heartbeat "Register timeout", interval:

View File

@@ -583,7 +583,8 @@ method handshake*(
)
conn.peerId = pid
var tmp = NoiseConnection.new(conn, conn.peerId, conn.observedAddr)
var tmp =
NoiseConnection.new(conn, conn.peerId, conn.observedAddr, conn.localAddr)
if initiator:
tmp.readCs = handshakeRes.cs2
tmp.writeCs = handshakeRes.cs1

View File

@@ -51,12 +51,14 @@ proc new*(
conn: Connection,
peerId: PeerId,
observedAddr: Opt[MultiAddress],
localAddr: Opt[MultiAddress],
timeout: Duration = DefaultConnectionTimeout,
): T =
result = T(
stream: conn,
peerId: peerId,
observedAddr: observedAddr,
localAddr: localAddr,
closeEvent: conn.closeEvent,
timeout: timeout,
dir: conn.dir,

View File

@@ -62,8 +62,15 @@ proc init*(
dir: Direction,
timeout = DefaultChronosStreamTimeout,
observedAddr: Opt[MultiAddress],
localAddr: Opt[MultiAddress],
): ChronosStream =
result = C(client: client, timeout: timeout, dir: dir, observedAddr: observedAddr)
result = C(
client: client,
timeout: timeout,
dir: dir,
observedAddr: observedAddr,
localAddr: localAddr,
)
result.initStream()
template withExceptions(body: untyped) =
@@ -151,6 +158,19 @@ method closed*(s: ChronosStream): bool =
method atEof*(s: ChronosStream): bool =
s.client.atEof()
method closeWrite*(s: ChronosStream) {.async: (raises: []).} =
## Close the write side of the TCP connection using half-close
if not s.client.closed():
try:
await s.client.shutdownWait()
trace "Write side closed", address = $s.client.remoteAddress(), s
except TransportError:
# Ignore transport errors during shutdown
discard
except CatchableError:
# Ignore other errors during shutdown
discard
method closeImpl*(s: ChronosStream) {.async: (raises: []).} =
trace "Shutting down chronos stream", address = $s.client.remoteAddress(), s

View File

@@ -33,6 +33,7 @@ type
timeoutHandler*: TimeoutHandler # timeout handler
peerId*: PeerId
observedAddr*: Opt[MultiAddress]
localAddr*: Opt[MultiAddress]
protocol*: string # protocol used by the connection, used as metrics tag
transportDir*: Direction # underlying transport (usually socket) direction
when defined(libp2p_agents_metrics):
@@ -40,6 +41,12 @@ type
proc timeoutMonitor(s: Connection) {.async: (raises: []).}
method closeWrite*(s: Connection): Future[void] {.base, async: (raises: []).} =
## Close the write side of the connection
## Subclasses should implement this for their specific transport
## Default implementation just closes the entire connection
await s.close()
func shortLog*(conn: Connection): string =
try:
if conn == nil:
@@ -133,13 +140,17 @@ when defined(libp2p_agents_metrics):
var conn = s
while conn != nil:
conn.shortAgent = shortAgent
conn = conn.getWrapped()
let wrapped = conn.getWrapped()
if wrapped == conn:
break
conn = wrapped
proc new*(
C: type Connection,
peerId: PeerId,
dir: Direction,
observedAddr: Opt[MultiAddress],
observedAddr: Opt[MultiAddress] = Opt.none(MultiAddress),
localAddr: Opt[MultiAddress] = Opt.none(MultiAddress),
timeout: Duration = DefaultConnectionTimeout,
timeoutHandler: TimeoutHandler = nil,
): Connection =
@@ -149,6 +160,7 @@ proc new*(
timeout: timeout,
timeoutHandler: timeoutHandler,
observedAddr: observedAddr,
localAddr: localAddr,
)
result.initStream()

View File

@@ -36,14 +36,19 @@ type QuicStream* = ref object of P2PConnection
cached: seq[byte]
proc new(
_: type QuicStream, stream: Stream, oaddr: Opt[MultiAddress], peerId: PeerId
_: type QuicStream,
stream: Stream,
oaddr: Opt[MultiAddress],
laddr: Opt[MultiAddress],
peerId: PeerId,
): QuicStream =
let quicstream = QuicStream(stream: stream, observedAddr: oaddr, peerId: peerId)
let quicstream =
QuicStream(stream: stream, observedAddr: oaddr, localAddr: laddr, peerId: peerId)
procCall P2PConnection(quicstream).initStream()
quicstream
method getWrapped*(self: QuicStream): P2PConnection =
nil
self
template mapExceptions(body: untyped) =
try:
@@ -56,18 +61,23 @@ template mapExceptions(body: untyped) =
method readOnce*(
stream: QuicStream, pbytes: pointer, nbytes: int
): Future[int] {.async: (raises: [CancelledError, LPStreamError]).} =
try:
if stream.cached.len == 0:
if stream.cached.len == 0:
try:
stream.cached = await stream.stream.read()
if stream.cached.len == 0:
raise newLPStreamEOFError()
except CancelledError as exc:
raise exc
except LPStreamEOFError as exc:
raise exc
except CatchableError as exc:
raise (ref LPStreamError)(msg: "error in readOnce: " & exc.msg, parent: exc)
result = min(nbytes, stream.cached.len)
copyMem(pbytes, addr stream.cached[0], result)
stream.cached = stream.cached[result ..^ 1]
libp2p_network_bytes.inc(result.int64, labelValues = ["in"])
except CatchableError as exc:
raise newLPStreamEOFError()
let toRead = min(nbytes, stream.cached.len)
copyMem(pbytes, addr stream.cached[0], toRead)
stream.cached = stream.cached[toRead ..^ 1]
libp2p_network_bytes.inc(toRead.int64, labelValues = ["in"])
return toRead
{.push warning[LockLevel]: off.}
method write*(
@@ -78,6 +88,13 @@ method write*(
{.pop.}
method closeWrite*(stream: QuicStream) {.async: (raises: []).} =
## Close the write side of the QUIC stream
try:
await stream.stream.closeWrite()
except CatchableError as exc:
discard
method closeImpl*(stream: QuicStream) {.async: (raises: []).} =
try:
await stream.stream.close()
@@ -108,7 +125,11 @@ proc getStream*(
stream = await session.connection.openStream()
await stream.write(@[]) # QUIC streams do not exist until data is sent
let qs = QuicStream.new(stream, session.observedAddr, session.peerId)
let qs =
QuicStream.new(stream, session.observedAddr, session.localAddr, session.peerId)
when defined(libp2p_agents_metrics):
qs.shortAgent = session.shortAgent
session.streams.add(qs)
return qs
except CatchableError as exc:
@@ -116,13 +137,20 @@ proc getStream*(
raise (ref QuicTransportError)(msg: "error in getStream: " & exc.msg, parent: exc)
method getWrapped*(self: QuicSession): P2PConnection =
nil
self
# Muxer
type QuicMuxer = ref object of Muxer
quicSession: QuicSession
handleFut: Future[void]
when defined(libp2p_agents_metrics):
method setShortAgent*(m: QuicMuxer, shortAgent: string) =
m.quicSession.shortAgent = shortAgent
for s in m.quicSession.streams:
s.shortAgent = shortAgent
m.connection.shortAgent = shortAgent
method newStream*(
m: QuicMuxer, name: string = "", lazy: bool = false
): Future[P2PConnection] {.
@@ -141,7 +169,7 @@ proc handleStream(m: QuicMuxer, chann: QuicStream) {.async: (raises: []).} =
trace "finished handling stream"
doAssert(chann.closed, "connection not closed by handler!")
except CatchableError as exc:
trace "Exception in mplex stream handler", msg = exc.msg
trace "Exception in quic stream handler", msg = exc.msg
await chann.close()
method handle*(m: QuicMuxer): Future[void] {.async: (raises: []).} =
@@ -150,7 +178,7 @@ method handle*(m: QuicMuxer): Future[void] {.async: (raises: []).} =
let incomingStream = await m.quicSession.getStream(Direction.In)
asyncSpawn m.handleStream(incomingStream)
except CatchableError as exc:
trace "Exception in mplex handler", msg = exc.msg
trace "Exception in quic handler", msg = exc.msg
method close*(m: QuicMuxer) {.async: (raises: []).} =
try:
@@ -262,7 +290,8 @@ method start*(
method stop*(transport: QuicTransport) {.async: (raises: []).} =
if transport.running:
for c in transport.connections:
let conns = transport.connections[0 .. ^1]
for c in conns:
await c.close()
await procCall Transport(transport).stop()
try:
@@ -277,11 +306,17 @@ proc wrapConnection(
transport: QuicTransport, connection: QuicConnection
): QuicSession {.raises: [TransportOsError, MaError].} =
let
remoteAddr = connection.remoteAddress()
observedAddr =
MultiAddress.init(remoteAddr, IPPROTO_UDP).get() &
MultiAddress.init(connection.remoteAddress(), IPPROTO_UDP).get() &
MultiAddress.init("/quic-v1").get()
session = QuicSession(connection: connection, observedAddr: Opt.some(observedAddr))
localAddr =
MultiAddress.init(connection.localAddress(), IPPROTO_UDP).get() &
MultiAddress.init("/quic-v1").get()
session = QuicSession(
connection: connection,
observedAddr: Opt.some(observedAddr),
localAddr: Opt.some(localAddr),
)
session.initStream()
@@ -301,12 +336,12 @@ method accept*(
): Future[connection.Connection] {.
async: (raises: [transport.TransportError, CancelledError])
.} =
doAssert not self.listener.isNil, "call start() before calling accept()"
if not self.running:
# stop accept only when transport is stopped (not when error occurs)
raise newException(QuicTransportAcceptStopped, "Quic transport stopped")
doAssert not self.listener.isNil, "call start() before calling accept()"
try:
let connection = await self.listener.accept()
return self.wrapConnection(connection)

View File

@@ -47,6 +47,7 @@ proc connHandler*(
self: TcpTransport,
client: StreamTransport,
observedAddr: Opt[MultiAddress],
localAddr: Opt[MultiAddress],
dir: Direction,
): Connection =
trace "Handling tcp connection",
@@ -59,6 +60,7 @@ proc connHandler*(
client = client,
dir = dir,
observedAddr = observedAddr,
localAddr = localAddr,
timeout = self.connectionsTimeout,
)
)
@@ -267,18 +269,22 @@ method accept*(
safeCloseWait(transp)
raise newTransportClosedError()
let remote =
let (localAddr, observedAddr) =
try:
transp.remoteAddress
(
MultiAddress.init(transp.localAddress).expect(
"Can initialize from local address"
),
MultiAddress.init(transp.remoteAddress).expect(
"Can initialize from remote address"
),
)
except TransportOsError as exc:
# The connection had errors / was closed before `await` returned control
safeCloseWait(transp)
debug "Cannot read remote address", description = exc.msg
debug "Cannot read address", description = exc.msg
return nil
let observedAddr =
MultiAddress.init(remote).expect("Can initialize from remote address")
self.connHandler(transp, Opt.some(observedAddr), Direction.In)
self.connHandler(transp, Opt.some(observedAddr), Opt.some(localAddr), Direction.In)
method dial*(
self: TcpTransport,
@@ -320,14 +326,17 @@ method dial*(
safeCloseWait(transp)
raise newTransportClosedError()
let observedAddr =
let (observedAddr, localAddr) =
try:
MultiAddress.init(transp.remoteAddress).expect("remote address is valid")
(
MultiAddress.init(transp.remoteAddress).expect("remote address is valid"),
MultiAddress.init(transp.localAddress).expect("local address is valid"),
)
except TransportOsError as exc:
safeCloseWait(transp)
raise (ref TcpTransportError)(msg: "MultiAddress.init error in dial: " & exc.msg)
self.connHandler(transp, Opt.some(observedAddr), Direction.Out)
self.connHandler(transp, Opt.some(observedAddr), Opt.some(localAddr), Direction.Out)
method handles*(t: TcpTransport, address: MultiAddress): bool {.raises: [].} =
if procCall Transport(t).handles(address):

View File

@@ -237,7 +237,9 @@ method dial*(
try:
transp = await connectToTorServer(self.transportAddress)
await dialPeer(transp, address)
return self.tcpTransport.connHandler(transp, Opt.none(MultiAddress), Direction.Out)
return self.tcpTransport.connHandler(
transp, Opt.none(MultiAddress), Opt.none(MultiAddress), Direction.Out
)
except CancelledError as e:
safeCloseWait(transp)
raise e

View File

@@ -18,9 +18,9 @@ import
../multicodec,
../muxers/muxer,
../upgrademngrs/upgrade,
../protocols/connectivity/autonat/core
../protocols/connectivity/autonat/types
export core.NetworkReachability
export types.NetworkReachability
logScope:
topics = "libp2p transport"

View File

@@ -55,10 +55,16 @@ proc new*(
session: WSSession,
dir: Direction,
observedAddr: Opt[MultiAddress],
localAddr: Opt[MultiAddress],
timeout = 10.minutes,
): T =
let stream =
T(session: session, timeout: timeout, dir: dir, observedAddr: observedAddr)
let stream = T(
session: session,
timeout: timeout,
dir: dir,
observedAddr: observedAddr,
localAddr: localAddr,
)
stream.initStream()
return stream
@@ -212,11 +218,9 @@ method stop*(self: WsTransport) {.async: (raises: []).} =
trace "Stopping WS transport"
await procCall Transport(self).stop() # call base
checkFutures(
await allFinished(
self.connections[Direction.In].mapIt(it.close()) &
self.connections[Direction.Out].mapIt(it.close())
)
discard await allFinished(
self.connections[Direction.In].mapIt(it.close()) &
self.connections[Direction.Out].mapIt(it.close())
)
var toWait: seq[Future[void]]
@@ -241,9 +245,8 @@ proc connHandler(
self: WsTransport, stream: WSSession, secure: bool, dir: Direction
): Future[Connection] {.async: (raises: [CatchableError]).} =
## Returning CatchableError is fine because we later handle different exceptions.
##
let observedAddr =
let (observedAddr, localAddr) =
try:
let
codec =
@@ -252,15 +255,19 @@ proc connHandler(
else:
MultiAddress.init("/ws")
remoteAddr = stream.stream.reader.tsource.remoteAddress
localAddr = stream.stream.reader.tsource.localAddress
MultiAddress.init(remoteAddr).tryGet() & codec.tryGet()
(
MultiAddress.init(remoteAddr).tryGet() & codec.tryGet(),
MultiAddress.init(localAddr).tryGet() & codec.tryGet(),
)
except CatchableError as exc:
trace "Failed to create observedAddr", description = exc.msg
trace "Failed to create observedAddr or listenAddr", description = exc.msg
if not (isNil(stream) and stream.stream.reader.closed):
safeClose(stream)
raise exc
let conn = WsStream.new(stream, dir, Opt.some(observedAddr))
let conn = WsStream.new(stream, dir, Opt.some(observedAddr), Opt.some(localAddr))
self.connections[dir].add(conn)
proc onClose() {.async: (raises: []).} =

74
libp2p/utils/ipaddr.nim Normal file
View File

@@ -0,0 +1,74 @@
import net, strutils
import ../switch, ../multiaddress, ../multicodec
proc isIPv4*(ip: IpAddress): bool =
ip.family == IpAddressFamily.IPv4
proc isIPv6*(ip: IpAddress): bool =
ip.family == IpAddressFamily.IPv6
proc isPrivate*(ip: string): bool {.raises: [ValueError].} =
ip.startsWith("10.") or
(ip.startsWith("172.") and parseInt(ip.split(".")[1]) in 16 .. 31) or
ip.startsWith("192.168.") or ip.startsWith("127.") or ip.startsWith("169.254.")
proc isPrivate*(ip: IpAddress): bool {.raises: [ValueError].} =
isPrivate($ip)
proc isPublic*(ip: string): bool {.raises: [ValueError].} =
not isPrivate(ip)
proc isPublic*(ip: IpAddress): bool {.raises: [ValueError].} =
isPublic($ip)
proc getPublicIPAddress*(): IpAddress {.raises: [OSError, ValueError].} =
let ip =
try:
getPrimaryIPAddr()
except OSError as exc:
raise exc
except ValueError as exc:
raise exc
except Exception as exc:
raise newException(OSError, "Could not get primary IP address")
if not ip.isIPv4():
raise newException(ValueError, "Host does not have an IPv4 address")
if not ip.isPublic():
raise newException(ValueError, "Host does not have a public IPv4 address")
ip
proc ipAddrMatches*(
lookup: MultiAddress, addrs: seq[MultiAddress], ip4: bool = true
): bool =
## Checks ``lookup``'s IP is in any of addrs
let ipType =
if ip4:
multiCodec("ip4")
else:
multiCodec("ip6")
let lookup = lookup.getPart(ipType).valueOr:
return false
for ma in addrs:
ma[0].withValue(ipAddr):
if ipAddr == lookup:
return true
false
proc ipSupport*(addrs: seq[MultiAddress]): (bool, bool) =
## Returns ipv4 and ipv6 support status of a list of MultiAddresses
var ipv4 = false
var ipv6 = false
for ma in addrs:
ma[0].withValue(addrIp):
if IP4.match(addrIp):
ipv4 = true
elif IP6.match(addrIp):
ipv6 = true
(ipv4, ipv6)

View File

@@ -11,7 +11,8 @@ fi
# Clean up output
output_dir="$(pwd)/performance/output"
mkdir -p "$output_dir"
rm -f "$output_dir"/*.json
rm -rf "$output_dir"
mkdir -p "$output_dir/sync"
# Run Test Nodes
container_names=()

View File

@@ -12,6 +12,7 @@
import metrics
import metrics/chronos_httpserver
import os
import osproc
import strformat
import strutils
import ../libp2p
@@ -42,6 +43,14 @@ proc baseTest*(scenarioName = "Base test") {.async.} =
if nodeId == 0:
clearSyncFiles()
# --- Collect docker stats for one publishing and one non-publishing node ---
var dockerStatsProc: Process = nil
if nodeId == 0 or nodeId == publisherCount + 1:
let dockerStatsLogPath = getDockerStatsLogPath(scenario, nodeId)
dockerStatsProc = startDockerStatsProcess(nodeId, dockerStatsLogPath)
defer:
dockerStatsProc.stopDockerStatsProcess()
let (switch, gossipSub, pingProtocol) = setupNode(nodeId, rng)
gossipSub.setGossipSubParams()

View File

@@ -0,0 +1,87 @@
import os
import algorithm
import sequtils
import strformat
import strutils
import tables
proc getImgUrlBase(repo: string, publishBranchName: string, plotsPath: string): string =
&"https://raw.githubusercontent.com/{repo}/refs/heads/{publishBranchName}/{plotsPath}"
proc extractTestName(base: string): string =
let parts = base.split("_")
if parts.len >= 2:
parts[^2]
else:
base
proc makeImgTag(imgUrl: string, width: int): string =
&"<img src=\"{imgUrl}\" width=\"{width}\" style=\"margin-right:10px;\" />"
proc prepareLatencyHistoryImage(
imgUrlBase: string, latencyHistoryFilePath: string, width: int = 600
): string =
let latencyImgUrl = &"{imgUrlBase}/{latencyHistoryFilePath}"
makeImgTag(latencyImgUrl, width)
proc prepareDockerStatsImages(
plotDir: string, imgUrlBase: string, branchName: string, width: int = 450
): Table[string, seq[string]] =
## Groups docker stats plot images by test name and returns HTML <img> tags.
var grouped: Table[string, seq[string]]
for path in walkFiles(&"{plotDir}/*.png"):
let plotFile = path.splitPath.tail
let testName = extractTestName(plotFile)
let imgUrl = &"{imgUrlBase}/{branchName}/{plotFile}"
let imgTag = makeImgTag(imgUrl, width)
discard grouped.hasKeyOrPut(testName, @[])
grouped[testName].add(imgTag)
grouped
proc buildSummary(
plotDir: string,
repo: string,
branchName: string,
publishBranchName: string,
plotsPath: string,
latencyHistoryFilePath: string,
): string =
let imgUrlBase = getImgUrlBase(repo, publishBranchName, plotsPath)
var buf: seq[string]
# Latency History section
buf.add("## Latency History")
buf.add(prepareLatencyHistoryImage(imgUrlBase, latencyHistoryFilePath) & "<br>")
buf.add("")
# Performance Plots section
let grouped = prepareDockerStatsImages(plotDir, imgUrlBase, branchName)
buf.add(&"## Performance Plots for {branchName}")
for test in grouped.keys.toSeq().sorted():
let imgs = grouped[test]
buf.add(&"### {test}")
buf.add(imgs.join(" ") & "<br>")
buf.join("\n")
proc main() =
let summaryPath = getEnv("GITHUB_STEP_SUMMARY", "/tmp/step_summary.md")
let repo = getEnv("GITHUB_REPOSITORY", "vacp2p/nim-libp2p")
let branchName = getEnv("BRANCH_NAME", "")
let publishBranchName = getEnv("PUBLISH_BRANCH_NAME", "performance_plots")
let plotsPath = getEnv("PLOTS_PATH", "plots")
let latencyHistoryFilePath =
getEnv("LATENCY_HISTORY_PLOT_FILENAME", "latency_history_all_scenarios.png")
let checkoutSubfolder = getEnv("CHECKOUT_SUBFOLDER", "subplots")
let plotDir = &"{checkoutSubfolder}/{plotsPath}/{branchName}"
let summary = buildSummary(
plotDir, repo, branchName, publishBranchName, plotsPath, latencyHistoryFilePath
)
writeFile(summaryPath, summary)
echo summary
main()

View File

@@ -0,0 +1,126 @@
import os
import glob
import csv
import statistics
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def parse_csv(filepath):
timestamps = []
cpu_percent = []
mem_usage_mb = []
download_MBps = []
upload_MBps = []
download_MB = []
upload_MB = []
with open(filepath, "r") as f:
reader = csv.DictReader(f)
for row in reader:
timestamps.append(float(row["timestamp"]))
cpu_percent.append(float(row["cpu_percent"]))
mem_usage_mb.append(float(row["mem_usage_mb"]))
download_MBps.append(float(row["download_MBps"]))
upload_MBps.append(float(row["upload_MBps"]))
download_MB.append(float(row["download_MB"]))
upload_MB.append(float(row["upload_MB"]))
return {
"timestamps": timestamps,
"cpu_percent": cpu_percent,
"mem_usage_mb": mem_usage_mb,
"download_MBps": download_MBps,
"upload_MBps": upload_MBps,
"download_MB": download_MB,
"upload_MB": upload_MB,
}
def plot_metrics(data, title, output_path):
timestamps = data["timestamps"]
time_points = [t - timestamps[0] for t in timestamps]
cpu = data["cpu_percent"]
mem = data["mem_usage_mb"]
download_MBps = data["download_MBps"]
upload_MBps = data["upload_MBps"]
download_MB = data["download_MB"]
upload_MB = data["upload_MB"]
cpu_median = statistics.median(cpu)
cpu_max = max(cpu)
mem_median = statistics.median(mem)
mem_max = max(mem)
download_MBps_median = statistics.median(download_MBps)
download_MBps_max = max(download_MBps)
upload_MBps_median = statistics.median(upload_MBps)
upload_MBps_max = max(upload_MBps)
download_MB_total = download_MB[-1]
upload_MB_total = upload_MB[-1]
fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, figsize=(12, 16), sharex=True)
fig.suptitle(title, fontsize=16)
# CPU Usage
ax1.plot(time_points, cpu, "b-", label=f"CPU Usage (%)\nmedian = {cpu_median:.2f}\nmax = {cpu_max:.2f}")
ax1.set_ylabel("CPU Usage (%)")
ax1.set_title("CPU Usage Over Time")
ax1.grid(True)
ax1.set_xlim(left=0)
ax1.set_ylim(bottom=0)
ax1.legend(loc="best")
# Memory Usage
ax2.plot(time_points, mem, "m-", label=f"Memory Usage (MB)\nmedian = {mem_median:.2f} MB\nmax = {mem_max:.2f} MB")
ax2.set_ylabel("Memory Usage (MB)")
ax2.set_title("Memory Usage Over Time")
ax2.grid(True)
ax2.set_xlim(left=0)
ax2.set_ylim(bottom=0)
ax2.legend(loc="best")
# Network Throughput
ax3.plot(
time_points,
download_MBps,
"c-",
label=f"Download (MB/s)\nmedian = {download_MBps_median:.2f} MB/s\nmax = {download_MBps_max:.2f} MB/s",
linewidth=2,
)
ax3.plot(
time_points, upload_MBps, "r-", label=f"Upload (MB/s)\nmedian = {upload_MBps_median:.2f} MB/s\nmax = {upload_MBps_max:.2f} MB/s", linewidth=2
)
ax3.set_ylabel("Network Throughput (MB/s)")
ax3.set_title("Network Activity Over Time")
ax3.grid(True)
ax3.set_xlim(left=0)
ax3.set_ylim(bottom=0)
ax3.legend(loc="best", labelspacing=2)
# Accumulated Network Data
ax4.plot(time_points, download_MB, "c-", label=f"Download (MB), total: {download_MB_total:.2f} MB", linewidth=2)
ax4.plot(time_points, upload_MB, "r-", label=f"Upload (MB), total: {upload_MB_total:.2f} MB", linewidth=2)
ax4.set_xlabel("Time (seconds)")
ax4.set_ylabel("Total Data Transferred (MB)")
ax4.set_title("Accumulated Network Data Over Time")
ax4.grid(True)
ax4.set_xlim(left=0)
ax4.set_ylim(bottom=0)
ax4.legend(loc="best")
plt.tight_layout(rect=(0, 0, 1, 1))
os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
plt.savefig(output_path, dpi=100, bbox_inches="tight")
plt.close(fig)
print(f"Saved plot to {output_path}")
if __name__ == "__main__":
shared_volume_path = os.environ.get("SHARED_VOLUME_PATH", "performance/output")
docker_stats_prefix = os.environ.get("DOCKER_STATS_PREFIX", "docker_stats_")
glob_pattern = os.path.join(shared_volume_path, f"{docker_stats_prefix}*.csv")
csv_files = glob.glob(glob_pattern)
for csv_file in csv_files:
file_name = os.path.splitext(os.path.basename(csv_file))[0]
data = parse_csv(csv_file)
plot_metrics(data, title=file_name, output_path=os.path.join(shared_volume_path, f"{file_name}.png"))

View File

@@ -0,0 +1,99 @@
import os
import glob
import csv
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def extract_pr_number(filename):
"""Extract PR number from filename of format pr{number}_anything.csv"""
fname = os.path.basename(filename)
parts = fname.split("_", 1)
pr_str = parts[0][2:]
if not pr_str.isdigit():
return None
return int(pr_str)
def parse_latency_csv(csv_files):
pr_numbers = []
scenario_data = {} # scenario -> {pr_num: {min, avg, max}}
for csv_file in csv_files:
pr_num = extract_pr_number(csv_file)
if pr_num is None:
continue
pr_numbers.append(pr_num)
with open(csv_file, newline="") as f:
reader = csv.DictReader(f)
for row in reader:
scenario = row["Scenario"]
if scenario not in scenario_data:
scenario_data[scenario] = {}
scenario_data[scenario][pr_num] = {
"min": float(row["MinLatencyMs"]),
"avg": float(row["AvgLatencyMs"]),
"max": float(row["MaxLatencyMs"]),
}
pr_numbers = sorted(set(pr_numbers))
return pr_numbers, scenario_data
def plot_latency_history(pr_numbers, scenario_data, output_path):
if not pr_numbers or not scenario_data:
print("No PR latency data found; skipping plot generation.")
return
num_scenarios = len(scenario_data)
fig, axes = plt.subplots(num_scenarios, 1, figsize=(14, 4 * num_scenarios), sharex=True)
if num_scenarios == 1:
axes = [axes]
color_map = plt.colormaps.get_cmap("tab10")
x_positions = list(range(len(pr_numbers)))
for i, (scenario, pr_stats) in enumerate(scenario_data.items()):
ax = axes[i]
min_vals = [pr_stats.get(pr, {"min": None})["min"] for pr in pr_numbers]
avg_vals = [pr_stats.get(pr, {"avg": None})["avg"] for pr in pr_numbers]
max_vals = [pr_stats.get(pr, {"max": None})["max"] for pr in pr_numbers]
color = color_map(i % color_map.N)
if any(v is not None for v in avg_vals):
ax.plot(x_positions, avg_vals, marker="o", label="Avg Latency (ms)", color=color)
ax.fill_between(x_positions, min_vals, max_vals, color=color, alpha=0.2, label="Min-Max Latency (ms)")
for x, avg, minv, maxv in zip(x_positions, avg_vals, min_vals, max_vals):
if avg is not None:
ax.scatter(x, avg, color=color)
ax.text(x, avg, f"{avg:.3f}", fontsize=14, ha="center", va="bottom")
if minv is not None and maxv is not None:
ax.vlines(x, minv, maxv, color=color, alpha=0.5)
ax.set_ylabel("Latency (ms)")
ax.set_title(f"Scenario: {scenario}")
ax.legend(loc="upper left", fontsize="small")
ax.grid(True, linestyle="--", alpha=0.5)
# Set X axis ticks and labels to show all PR numbers as 'PR <number>'
axes[-1].set_xlabel("PR Number")
axes[-1].set_xticks(x_positions)
axes[-1].set_xticklabels([f"PR {pr}" for pr in pr_numbers], rotation=45, ha="right", fontsize=14)
plt.tight_layout()
plt.savefig(output_path)
print(f"Saved combined plot to {output_path}")
plt.close(fig)
if __name__ == "__main__":
LATENCY_HISTORY_PATH = os.environ.get("LATENCY_HISTORY_PATH", "performance/output")
LATENCY_HISTORY_PREFIX = os.environ.get("LATENCY_HISTORY_PREFIX", "pr")
LATENCY_HISTORY_PLOT_FILENAME = os.environ.get("LATENCY_HISTORY_PLOT_FILENAME", "pr")
glob_pattern = os.path.join(LATENCY_HISTORY_PATH, f"{LATENCY_HISTORY_PREFIX}[0-9]*_latency.csv")
csv_files = sorted(glob.glob(glob_pattern))
pr_numbers, scenario_data = parse_latency_csv(csv_files)
output_path = os.path.join(LATENCY_HISTORY_PATH, LATENCY_HISTORY_PLOT_FILENAME)
plot_latency_history(pr_numbers, scenario_data, output_path)

View File

@@ -0,0 +1,156 @@
from times import parse, toTime, toUnix
import strformat
import strutils
import json
import os
import options
type DockerStatsSample = object
timestamp: float
cpuPercent: float
memUsageMB: float
netRxMB: float
netTxMB: float
proc parseTimestamp(statsJson: JsonNode): float =
let isoStr = statsJson["read"].getStr("")
let mainPart = isoStr[0 ..< ^1] # remove trailing 'Z'
let parts = mainPart.split(".")
let dt = parse(parts[0], "yyyy-MM-dd'T'HH:mm:ss")
var nanos = 0
if parts.len == 2:
let nsStr = parts[1]
let nsStrPadded = nsStr & repeat('0', 9 - nsStr.len)
nanos = parseInt(nsStrPadded)
let epochNano = dt.toTime.toUnix * 1_000_000_000 + nanos
# Return timestamp in seconds since Unix epoch
return float(epochNano) / 1_000_000_000.0
proc extractCpuRaw(statsJson: JsonNode): (int, int, int) =
let cpuStats = statsJson["cpu_stats"]
let precpuStats = statsJson["precpu_stats"]
let totalUsage = cpuStats["cpu_usage"]["total_usage"].getInt(0)
let prevTotalUsage = precpuStats["cpu_usage"]["total_usage"].getInt(0)
let systemUsage = cpuStats["system_cpu_usage"].getInt(0)
let prevSystemUsage = precpuStats["system_cpu_usage"].getInt(0)
let numCpus = cpuStats["online_cpus"].getInt(0)
return (totalUsage - prevTotalUsage, systemUsage - prevSystemUsage, numCpus)
proc calcCpuPercent(cpuDelta: int, systemDelta: int, numCpus: int): float =
if systemDelta > 0 and cpuDelta > 0 and numCpus > 0:
return (float(cpuDelta) / float(systemDelta)) * float(numCpus) * 100.0
else:
return 0.0
proc extractMemUsageRaw(statsJson: JsonNode): int =
let memStats = statsJson["memory_stats"]
return memStats["usage"].getInt(0)
proc extractNetworkRaw(statsJson: JsonNode): (int, int) =
var netRxBytes = 0
var netTxBytes = 0
if "networks" in statsJson:
for k, v in statsJson["networks"]:
netRxBytes += v["rx_bytes"].getInt(0)
netTxBytes += v["tx_bytes"].getInt(0)
return (netRxBytes, netTxBytes)
proc convertMB(bytes: int): float =
return float(bytes) / 1024.0 / 1024.0
proc parseDockerStatsLine(line: string): Option[DockerStatsSample] =
var samples = none(DockerStatsSample)
if line.len == 0:
return samples
try:
let statsJson = parseJson(line)
let timestamp = parseTimestamp(statsJson)
let (cpuDelta, systemDelta, numCpus) = extractCpuRaw(statsJson)
let cpuPercent = calcCpuPercent(cpuDelta, systemDelta, numCpus)
let memUsageMB = extractMemUsageRaw(statsJson).convertMB()
let (netRxRaw, netTxRaw) = extractNetworkRaw(statsJson)
let netRxMB = netRxRaw.convertMB()
let netTxMB = netTxRaw.convertMB()
return some(
DockerStatsSample(
timestamp: timestamp,
cpuPercent: cpuPercent,
memUsageMB: memUsageMB,
netRxMB: netRxMB,
netTxMB: netTxMB,
)
)
except:
return samples
proc processDockerStatsLog*(inputPath: string): seq[DockerStatsSample] =
var samples: seq[DockerStatsSample]
for line in lines(inputPath):
let sampleOpt = parseDockerStatsLine(line)
if sampleOpt.isSome:
samples.add(sampleOpt.get)
return samples
proc calcRateMBps(curr: float, prev: float, dt: float): float =
if dt == 0:
return 0.0
return ((curr - prev)) / dt
proc writeCsvSeries(samples: seq[DockerStatsSample], outPath: string) =
var f = open(outPath, fmWrite)
f.writeLine(
"timestamp,cpu_percent,mem_usage_mb,download_MBps,upload_MBps,download_MB,upload_MB"
)
if samples.len == 0:
f.close()
return
let timeOffset = samples[0].timestamp
let memOffset = samples[0].memUsageMB
let rxOffset = samples[0].netRxMB
let txOffset = samples[0].netTxMB
var prevRx = samples[0].netRxMB
var prevTx = samples[0].netTxMB
var prevTimestamp = samples[0].timestamp - timeOffset
for s in samples:
let relTimestamp = s.timestamp - timeOffset
let dt = relTimestamp - prevTimestamp
let dlMBps = calcRateMBps(s.netRxMB, prevRx, dt)
let ulMBps = calcRateMBps(s.netTxMB, prevTx, dt)
let dlAcc = s.netRxMB - rxOffset
let ulAcc = s.netTxMB - txOffset
let memUsage = s.memUsageMB - memOffset
f.writeLine(
fmt"{relTimestamp:.2f},{s.cpuPercent:.2f},{memUsage:.2f},{dlMBps:.4f},{ulMBps:.4f},{dlAcc:.4f},{ulAcc:.4f}"
)
prevRx = s.netRxMB
prevTx = s.netTxMB
prevTimestamp = relTimestamp
f.close()
proc findInputFiles(dir: string, prefix: string): seq[string] =
var files: seq[string] = @[]
for entry in walkDir(dir):
if entry.kind == pcFile and entry.path.endsWith(".log") and
entry.path.contains(prefix):
files.add(entry.path)
return files
proc main() =
let dir = getEnv("SHARED_VOLUME_PATH", "performance/output")
let prefix = getEnv("DOCKER_STATS_PREFIX", "docker_stats_")
let inputFiles = findInputFiles(dir, prefix)
if inputFiles.len == 0:
echo "No docker stats files found."
return
for inputFile in inputFiles:
let processedStats = processDockerStatsLog(inputFile)
let outputCsvPath = inputFile.replace(".log", ".csv")
writeCsvSeries(processedStats, outputCsvPath)
echo fmt"Processed stats from {inputFile} written to {outputCsvPath}"
main()

View File

@@ -4,7 +4,7 @@ import sequtils
import strutils
import strformat
import tables
import ./types
import ../types
const unknownFloat = -1.0
@@ -93,7 +93,8 @@ proc getMarkdownReport*(
output.add marker & "\n"
output.add "# 🏁 **Performance Summary**\n"
output.add fmt"**Commit:** `{commitSha}`"
let commitUrl = fmt"https://github.com/vacp2p/nim-libp2p/commit/{commitSha}"
output.add fmt"**Commit:** [`{commitSha}`]({commitUrl})"
output.add "| Scenario | Nodes | Total messages sent | Total messages received | Latency min (ms) | Latency max (ms) | Latency avg (ms) |"
output.add "|:---:|:---:|:---:|:---:|:---:|:---:|:---:|"
@@ -102,10 +103,29 @@ proc getMarkdownReport*(
let nodes = validNodes[scenarioName]
output.add fmt"| {stats.scenarioName} | {nodes} | {stats.totalSent} | {stats.totalReceived} | {stats.latency.minLatencyMs:.3f} | {stats.latency.maxLatencyMs:.3f} | {stats.latency.avgLatencyMs:.3f} |"
let markdown = output.join("\n")
let runId = getEnv("GITHUB_RUN_ID", "")
let summaryUrl = fmt"https://github.com/vacp2p/nim-libp2p/actions/runs/{runId}"
output.add(
fmt"### 📊 View Latency History and full Container Resources in the [Workflow Summary]({summaryUrl})"
)
let markdown = output.join("\n")
return markdown
proc getCsvFilename*(outputDir: string): string =
let prNum = getEnv("PR_NUMBER", "unknown")
result = fmt"{outputDir}/pr{prNum}_latency.csv"
proc getCsvReport*(
results: Table[string, Stats], validNodes: Table[string, int]
): string =
var output: seq[string]
output.add "Scenario,Nodes,TotalSent,TotalReceived,MinLatencyMs,MaxLatencyMs,AvgLatencyMs"
for scenarioName, stats in results.pairs:
let nodes = validNodes[scenarioName]
output.add fmt"{stats.scenarioName},{nodes},{stats.totalSent},{stats.totalReceived},{stats.latency.minLatencyMs:.3f},{stats.latency.maxLatencyMs:.3f},{stats.latency.avgLatencyMs:.3f}"
result = output.join("\n")
proc main() =
let outputDir = "performance/output"
let parsedJsons = parseJsonFiles(outputDir)
@@ -113,6 +133,11 @@ proc main() =
let jsonResults = getJsonResults(parsedJsons)
let (aggregatedResults, validNodes) = aggregateResults(jsonResults)
# For History
let csvFilename = getCsvFilename(outputDir)
let csvContent = getCsvReport(aggregatedResults, validNodes)
writeFile(csvFilename, csvContent)
let marker = getEnv("MARKER", "<!-- marker -->")
let commitSha = getEnv("PR_HEAD_SHA", getEnv("GITHUB_SHA", "unknown"))
let markdown = getMarkdownReport(aggregatedResults, validNodes, marker, commitSha)

View File

@@ -266,7 +266,7 @@ proc syncNodes*(stage: string, nodeId, nodeCount: int) {.async.} =
writeFile(myFile, "ok")
let expectedFiles = (0 ..< nodeCount).mapIt(syncDir / (prefix & stage & "_" & $it))
checkUntilTimeoutCustom(5.seconds, 100.milliseconds):
checkUntilTimeoutCustom(15.seconds, 100.milliseconds):
expectedFiles.allIt(fileExists(it))
# final wait
@@ -279,3 +279,38 @@ proc clearSyncFiles*() =
for f in walkDir(syncDir):
if fileExists(f.path):
removeFile(f.path)
proc getDockerStatsLogPath*(scenarioName: string, nodeId: int): string =
let sanitizedScenario = scenarioName.replace(" ", "").replace("%", "percent")
return &"/output/docker_stats_{sanitizedScenario}_{nodeId}.log"
proc clearDockerStats*(outputPath: string) =
if fileExists(outputPath):
removeFile(outputPath)
proc getContainerId(nodeId: int): string =
let response = execShellCommand(
"curl -s --unix-socket /var/run/docker.sock http://localhost/containers/json"
)
let containers = parseJson(response)
let expectedName = "/node-" & $nodeId
let filtered =
containers.filterIt(it["Names"].getElems(@[]).anyIt(it.getStr("") == expectedName))
if filtered.len == 0:
return ""
return filtered[0]["Id"].getStr("")
proc startDockerStatsProcess*(nodeId: int, outputPath: string): Process =
let containerId = getContainerId(nodeId)
let shellCmd =
fmt"curl --unix-socket /var/run/docker.sock http://localhost/containers/{containerId}/stats > {outputPath} 2>/dev/null"
return startProcess(
"/bin/sh", args = ["-c", shellCmd], options = {poUsePath, poStdErrToStdOut}
)
proc stopDockerStatsProcess*(p: Process) =
if p != nil:
p.kill()
p.close()

View File

@@ -516,6 +516,7 @@ proc relayInteropTests*(name: string, relayCreator: SwitchCreator) =
suite "Interop relay using " & name:
asyncTest "NativeSrc -> NativeRelay -> DaemonDst":
let closeBlocker = newFuture[void]()
let daemonFinished = newFuture[void]()
# TODO: This Future blocks the daemonHandler after sending the last message.
# It exists because there's a strange behavior where stream.close sends
# a Rst instead of Fin. We should investigate this at some point.
@@ -528,6 +529,7 @@ proc relayInteropTests*(name: string, relayCreator: SwitchCreator) =
discard await stream.transp.writeLp("line4")
await closeBlocker
await stream.close()
daemonFinished.complete()
let
maSrc = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()
@@ -556,8 +558,18 @@ proc relayInteropTests*(name: string, relayCreator: SwitchCreator) =
check string.fromBytes(await conn.readLp(1024)) == "line4"
closeBlocker.complete()
await daemonFinished
await conn.close()
await allFutures(src.stop(), rel.stop())
await daemonNode.close()
try:
await daemonNode.close()
except CatchableError as e:
when defined(windows):
# On Windows, daemon close may fail due to socket race condition
# This is expected behavior and can be safely ignored
discard
else:
raise e
asyncTest "DaemonSrc -> NativeRelay -> NativeDst":
proc customHandler(

View File

@@ -32,6 +32,9 @@ template commonTransportTest*(prov: TransportBuilder, ma1: string, ma2: string =
let conn = await transport1.accept()
if conn.observedAddr.isSome():
check transport1.handles(conn.observedAddr.get())
# skip IP check, only check transport and port
check conn.localAddr.get()[3] == transport1.addrs[0][3]
check conn.localAddr.get()[4] == transport1.addrs[0][4]
await conn.close()
let handlerWait = acceptHandler()

View File

@@ -0,0 +1,3 @@
{.used.}
import testdiscoverymngr, testrendezvous, testrendezvousinterface

View File

@@ -11,25 +11,16 @@
import options, chronos, sets
import
../libp2p/[
../../libp2p/[
protocols/rendezvous,
switch,
builders,
discovery/discoverymngr,
discovery/rendezvousinterface,
]
import ./helpers, ./utils/async_tests
proc createSwitch(rdv: RendezVous = RendezVous.new()): Switch =
SwitchBuilder
.new()
.withRng(newRng())
.withAddresses(@[MultiAddress.init(MemoryAutoAddress).tryGet()])
.withMemoryTransport()
.withMplex()
.withNoise()
.withRendezVous(rdv)
.build()
import ../helpers
import ../utils/async_tests
import ./utils
suite "Discovery":
teardown:

View File

@@ -0,0 +1,471 @@
{.used.}
# Nim-Libp2p
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import sequtils, strutils
import chronos
import ../../libp2p/[protocols/rendezvous, switch]
import ../../libp2p/discovery/discoverymngr
import ../../libp2p/utils/offsettedseq
import ../helpers
import ./utils
suite "RendezVous":
teardown:
checkTrackers()
asyncTest "Request locally returns 0 for empty namespace":
let (nodes, rdvs) = setupNodes(1)
nodes.startAndDeferStop()
const namespace = ""
check rdvs[0].requestLocally(namespace).len == 0
asyncTest "Request locally returns registered peers":
let (nodes, rdvs) = setupNodes(1)
nodes.startAndDeferStop()
const namespace = "foo"
await rdvs[0].advertise(namespace)
let peerRecords = rdvs[0].requestLocally(namespace)
check:
peerRecords.len == 1
peerRecords[0] == nodes[0].peerInfo.signedPeerRecord.data
asyncTest "Unsubscribe Locally removes registered peer":
let (nodes, rdvs) = setupNodes(1)
nodes.startAndDeferStop()
const namespace = "foo"
await rdvs[0].advertise(namespace)
check rdvs[0].requestLocally(namespace).len == 1
rdvs[0].unsubscribeLocally(namespace)
check rdvs[0].requestLocally(namespace).len == 0
asyncTest "Request returns 0 for empty namespace from remote":
let (rendezvousNode, peerNodes, peerRdvs, _) = setupRendezvousNodeWithPeerNodes(1)
(rendezvousNode & peerNodes).startAndDeferStop()
await connectNodes(peerNodes[0], rendezvousNode)
const namespace = "empty"
check (await peerRdvs[0].request(Opt.some(namespace))).len == 0
asyncTest "Request returns registered peers from remote":
let (rendezvousNode, peerNodes, peerRdvs, _) = setupRendezvousNodeWithPeerNodes(1)
(rendezvousNode & peerNodes).startAndDeferStop()
await connectNodes(peerNodes[0], rendezvousNode)
const namespace = "foo"
await peerRdvs[0].advertise(namespace)
let peerRecords = await peerRdvs[0].request(Opt.some(namespace))
check:
peerRecords.len == 1
peerRecords[0] == peerNodes[0].peerInfo.signedPeerRecord.data
asyncTest "Unsubscribe removes registered peer from remote":
let (rendezvousNode, peerNodes, peerRdvs, _) = setupRendezvousNodeWithPeerNodes(1)
(rendezvousNode & peerNodes).startAndDeferStop()
await connectNodes(peerNodes[0], rendezvousNode)
const namespace = "foo"
await peerRdvs[0].advertise(namespace)
check (await peerRdvs[0].request(Opt.some(namespace))).len == 1
await peerRdvs[0].unsubscribe(namespace)
check (await peerRdvs[0].request(Opt.some(namespace))).len == 0
asyncTest "Consecutive requests with namespace returns peers with pagination":
let (rendezvousNode, peerNodes, peerRdvs, _) = setupRendezvousNodeWithPeerNodes(11)
(rendezvousNode & peerNodes).startAndDeferStop()
await connectNodesToRendezvousNode(peerNodes, rendezvousNode)
const namespace = "foo"
await allFutures(peerRdvs.mapIt(it.advertise(namespace)))
var data = peerNodes.mapIt(it.peerInfo.signedPeerRecord.data)
var peerRecords = await peerRdvs[0].request(Opt.some(namespace), 5)
check:
peerRecords.len == 5
peerRecords.allIt(it in data)
data.keepItIf(it notin peerRecords)
peerRecords = await peerRdvs[0].request(Opt.some(namespace))
check:
peerRecords.len == 6
peerRecords.allIt(it in data)
check (await peerRdvs[0].request(Opt.some(namespace))).len == 0
asyncTest "Request without namespace returns all registered peers":
let (rendezvousNode, peerNodes, peerRdvs, _) = setupRendezvousNodeWithPeerNodes(10)
(rendezvousNode & peerNodes).startAndDeferStop()
await connectNodesToRendezvousNode(peerNodes, rendezvousNode)
const namespaceFoo = "foo"
const namespaceBar = "Bar"
await allFutures(peerRdvs[0 ..< 5].mapIt(it.advertise(namespaceFoo)))
await allFutures(peerRdvs[5 ..< 10].mapIt(it.advertise(namespaceBar)))
check (await peerRdvs[0].request()).len == 10
check (await peerRdvs[0].request(Opt.none(string))).len == 10
asyncTest "Consecutive requests with namespace keep cookie and retun only new peers":
let (rendezvousNode, peerNodes, peerRdvs, _) = setupRendezvousNodeWithPeerNodes(2)
(rendezvousNode & peerNodes).startAndDeferStop()
await connectNodesToRendezvousNode(peerNodes, rendezvousNode)
let
rdv0 = peerRdvs[0]
rdv1 = peerRdvs[1]
const namespace = "foo"
await rdv0.advertise(namespace)
discard await rdv0.request(Opt.some(namespace))
await rdv1.advertise(namespace)
let peerRecords = await rdv0.request(Opt.some(namespace))
check:
peerRecords.len == 1
peerRecords[0] == peerNodes[1].peerInfo.signedPeerRecord.data
asyncTest "Request with namespace pagination with multiple namespaces":
let (rendezvousNode, peerNodes, peerRdvs, _) = setupRendezvousNodeWithPeerNodes(30)
(rendezvousNode & peerNodes).startAndDeferStop()
await connectNodesToRendezvousNode(peerNodes, rendezvousNode)
let rdv = peerRdvs[0]
# Register peers in different namespaces in mixed order
const
namespaceFoo = "foo"
namespaceBar = "bar"
await allFutures(peerRdvs[0 ..< 5].mapIt(it.advertise(namespaceFoo)))
await allFutures(peerRdvs[5 ..< 10].mapIt(it.advertise(namespaceBar)))
await allFutures(peerRdvs[10 ..< 15].mapIt(it.advertise(namespaceFoo)))
await allFutures(peerRdvs[15 ..< 20].mapIt(it.advertise(namespaceBar)))
var fooRecords = peerNodes[0 ..< 5].concat(peerNodes[10 ..< 15]).mapIt(
it.peerInfo.signedPeerRecord.data
)
var barRecords = peerNodes[5 ..< 10].concat(peerNodes[15 ..< 20]).mapIt(
it.peerInfo.signedPeerRecord.data
)
# Foo Page 1 with limit
var peerRecords = await rdv.request(Opt.some(namespaceFoo), 2)
check:
peerRecords.len == 2
peerRecords.allIt(it in fooRecords)
fooRecords.keepItIf(it notin peerRecords)
# Foo Page 2 with limit
peerRecords = await rdv.request(Opt.some(namespaceFoo), 5)
check:
peerRecords.len == 5
peerRecords.allIt(it in fooRecords)
fooRecords.keepItIf(it notin peerRecords)
# Foo Page 3 with the rest
peerRecords = await rdv.request(Opt.some(namespaceFoo))
check:
peerRecords.len == 3
peerRecords.allIt(it in fooRecords)
fooRecords.keepItIf(it notin peerRecords)
check fooRecords.len == 0
# Foo Page 4 empty
peerRecords = await rdv.request(Opt.some(namespaceFoo))
check peerRecords.len == 0
# Bar Page 1 with all
peerRecords = await rdv.request(Opt.some(namespaceBar), 30)
check:
peerRecords.len == 10
peerRecords.allIt(it in barRecords)
barRecords.keepItIf(it notin peerRecords)
check barRecords.len == 0
# Register new peers
await allFutures(peerRdvs[20 ..< 25].mapIt(it.advertise(namespaceFoo)))
await allFutures(peerRdvs[25 ..< 30].mapIt(it.advertise(namespaceBar)))
# Foo Page 5 only new peers
peerRecords = await rdv.request(Opt.some(namespaceFoo))
check:
peerRecords.len == 5
peerRecords.allIt(
it in peerNodes[20 ..< 25].mapIt(it.peerInfo.signedPeerRecord.data)
)
# Bar Page 2 only new peers
peerRecords = await rdv.request(Opt.some(namespaceBar))
check:
peerRecords.len == 5
peerRecords.allIt(
it in peerNodes[25 ..< 30].mapIt(it.peerInfo.signedPeerRecord.data)
)
# All records
peerRecords = await rdv.request(Opt.none(string))
check peerRecords.len == 30
asyncTest "Request with namespace with expired peers":
let (rendezvousNode, peerNodes, peerRdvs, rdv) =
setupRendezvousNodeWithPeerNodes(20)
(rendezvousNode & peerNodes).startAndDeferStop()
await connectNodesToRendezvousNode(peerNodes, rendezvousNode)
# Advertise peers
const
namespaceFoo = "foo"
namespaceBar = "bar"
await allFutures(peerRdvs[0 ..< 5].mapIt(it.advertise(namespaceFoo)))
await allFutures(peerRdvs[5 ..< 10].mapIt(it.advertise(namespaceBar)))
check:
(await peerRdvs[0].request(Opt.some(namespaceFoo))).len == 5
(await peerRdvs[0].request(Opt.some(namespaceBar))).len == 5
# Overwrite register timeout loop interval
discard rdv.deletesRegister(1.seconds)
# Overwrite expiration times
let now = Moment.now()
for reg in rdv.registered.s.mitems:
reg.expiration = now
# Wait for the deletion
checkUntilTimeout:
rdv.registered.offset == 10
rdv.registered.s.len == 0
(await peerRdvs[0].request(Opt.some(namespaceFoo))).len == 0
(await peerRdvs[0].request(Opt.some(namespaceBar))).len == 0
# Advertise new peers
await allFutures(peerRdvs[10 ..< 15].mapIt(it.advertise(namespaceFoo)))
await allFutures(peerRdvs[15 ..< 20].mapIt(it.advertise(namespaceBar)))
check:
rdv.registered.offset == 10
rdv.registered.s.len == 10
(await peerRdvs[0].request(Opt.some(namespaceFoo))).len == 5
(await peerRdvs[0].request(Opt.some(namespaceBar))).len == 5
asyncTest "Cookie offset is reset to end (returns empty) then new peers are discoverable":
let (rendezvousNode, peerNodes, peerRdvs, rdv) = setupRendezvousNodeWithPeerNodes(3)
(rendezvousNode & peerNodes).startAndDeferStop()
await connectNodesToRendezvousNode(peerNodes, rendezvousNode)
const namespace = "foo"
# Advertise two peers initially
await allFutures(peerRdvs[0 ..< 2].mapIt(it.advertise(namespace)))
# Build and inject overflow cookie: offset past current high()+1
let offset = (rdv.registered.high + 1000).uint64
let cookie = buildProtobufCookie(offset, namespace)
peerRdvs[0].injectCookieForPeer(rendezvousNode.peerInfo.peerId, namespace, cookie)
# First request should return empty due to clamping to high()+1
check (await peerRdvs[0].request(Opt.some(namespace))).len == 0
# Advertise a new peer, next request should return only the new one
await peerRdvs[2].advertise(namespace)
let peerRecords = await peerRdvs[0].request(Opt.some(namespace))
check:
peerRecords.len == 1
peerRecords[0] == peerNodes[2].peerInfo.signedPeerRecord.data
asyncTest "Cookie offset is reset to low after flush (returns current entries)":
let (rendezvousNode, peerNodes, peerRdvs, rdv) = setupRendezvousNodeWithPeerNodes(8)
(rendezvousNode & peerNodes).startAndDeferStop()
await connectNodesToRendezvousNode(peerNodes, rendezvousNode)
const namespace = "foo"
# Advertise 4 peers in namespace
await allFutures(peerRdvs[0 ..< 4].mapIt(it.advertise(namespace)))
# Expire all and flush to advance registered.offset
discard rdv.deletesRegister(1.seconds)
let now = Moment.now()
for reg in rdv.registered.s.mitems:
reg.expiration = now
checkUntilTimeout:
rdv.registered.s.len == 0
rdv.registered.offset == 4
# Advertise 4 new peers
await allFutures(peerRdvs[4 ..< 8].mapIt(it.advertise(namespace)))
# Build and inject underflow cookie: offset behind current low
let offset = 0'u64
let cookie = buildProtobufCookie(offset, namespace)
peerRdvs[0].injectCookieForPeer(rendezvousNode.peerInfo.peerId, namespace, cookie)
check (await peerRdvs[0].request(Opt.some(namespace))).len == 4
asyncTest "Cookie namespace mismatch resets to low (returns peers despite offset)":
let (rendezvousNode, peerNodes, peerRdvs, rdv) = setupRendezvousNodeWithPeerNodes(3)
(rendezvousNode & peerNodes).startAndDeferStop()
await connectNodesToRendezvousNode(peerNodes, rendezvousNode)
const namespace = "foo"
await allFutures(peerRdvs.mapIt(it.advertise(namespace)))
# Build and inject cookie with wrong namespace
let offset = 10.uint64
let cookie = buildProtobufCookie(offset, "other")
peerRdvs[0].injectCookieForPeer(rendezvousNode.peerInfo.peerId, namespace, cookie)
check (await peerRdvs[0].request(Opt.some(namespace))).len == 3
asyncTest "Peer default TTL is saved when advertised":
let (rendezvousNode, peerNodes, peerRdvs, rendezvousRdv) =
setupRendezvousNodeWithPeerNodes(1)
(rendezvousNode & peerNodes).startAndDeferStop()
await connectNodes(peerNodes[0], rendezvousNode)
const namespace = "foo"
let timeBefore = Moment.now()
await peerRdvs[0].advertise(namespace)
let timeAfter = Moment.now()
# expiration within [timeBefore + 2hours, timeAfter + 2hours]
check:
# Peer Node side
peerRdvs[0].registered.s[0].data.ttl.get == MinimumDuration.seconds.uint64
peerRdvs[0].registered.s[0].expiration >= timeBefore + MinimumDuration
peerRdvs[0].registered.s[0].expiration <= timeAfter + MinimumDuration
# Rendezvous Node side
rendezvousRdv.registered.s[0].data.ttl.get == MinimumDuration.seconds.uint64
rendezvousRdv.registered.s[0].expiration >= timeBefore + MinimumDuration
rendezvousRdv.registered.s[0].expiration <= timeAfter + MinimumDuration
asyncTest "Peer TTL is saved when advertised with TTL":
let (rendezvousNode, peerNodes, peerRdvs, rendezvousRdv) =
setupRendezvousNodeWithPeerNodes(1)
(rendezvousNode & peerNodes).startAndDeferStop()
await connectNodes(peerNodes[0], rendezvousNode)
const
namespace = "foo"
ttl = 3.hours
let timeBefore = Moment.now()
await peerRdvs[0].advertise(namespace, ttl)
let timeAfter = Moment.now()
# expiration within [timeBefore + ttl, timeAfter + ttl]
check:
# Peer Node side
peerRdvs[0].registered.s[0].data.ttl.get == ttl.seconds.uint64
peerRdvs[0].registered.s[0].expiration >= timeBefore + ttl
peerRdvs[0].registered.s[0].expiration <= timeAfter + ttl
# Rendezvous Node side
rendezvousRdv.registered.s[0].data.ttl.get == ttl.seconds.uint64
rendezvousRdv.registered.s[0].expiration >= timeBefore + ttl
rendezvousRdv.registered.s[0].expiration <= timeAfter + ttl
asyncTest "Peer can reregister to update its TTL before previous TTL expires":
let (rendezvousNode, peerNodes, peerRdvs, rendezvousRdv) =
setupRendezvousNodeWithPeerNodes(1)
(rendezvousNode & peerNodes).startAndDeferStop()
await connectNodes(peerNodes[0], rendezvousNode)
const namespace = "foo"
let now = Moment.now()
await peerRdvs[0].advertise(namespace)
check:
# Peer Node side
peerRdvs[0].registered.s.len == 1
peerRdvs[0].registered.s[0].expiration > now
# Rendezvous Node side
rendezvousRdv.registered.s.len == 1
rendezvousRdv.registered.s[0].expiration > now
await peerRdvs[0].advertise(namespace, 5.hours)
check:
# Added 2nd registration
# Updated expiration of the 1st one to the past
# Will be deleted on deletion heartbeat
# Peer Node side
peerRdvs[0].registered.s.len == 2
peerRdvs[0].registered.s[0].expiration < now
# Rendezvous Node side
rendezvousRdv.registered.s.len == 2
rendezvousRdv.registered.s[0].expiration < now
# Returns only one record
check (await peerRdvs[0].request(Opt.some(namespace))).len == 1
asyncTest "Peer registration is ignored if limit of 1000 registrations is reached":
let (rendezvousNode, peerNodes, peerRdvs, rendezvousRdv) =
setupRendezvousNodeWithPeerNodes(1)
(rendezvousNode & peerNodes).startAndDeferStop()
await connectNodes(peerNodes[0], rendezvousNode)
const namespace = "foo"
let peerRdv = peerRdvs[0]
# Create 999 registrations
await populatePeerRegistrations(
peerRdv, rendezvousRdv, namespace, RegistrationLimitPerPeer - 1
)
# 1000th registration allowed
await peerRdv.advertise(namespace)
check rendezvousRdv.registered.s.len == RegistrationLimitPerPeer
# 1001st registration ignored, limit reached
await peerRdv.advertise(namespace)
check rendezvousRdv.registered.s.len == RegistrationLimitPerPeer
asyncTest "Various local error":
let rdv = RendezVous.new(minDuration = 1.minutes, maxDuration = 72.hours)
expect AdvertiseError:
discard await rdv.request(Opt.some("A".repeat(300)))
expect AdvertiseError:
discard await rdv.request(Opt.some("A"), -1)
expect AdvertiseError:
discard await rdv.request(Opt.some("A"), 3000)
expect AdvertiseError:
await rdv.advertise("A".repeat(300))
expect AdvertiseError:
await rdv.advertise("A", 73.hours)
expect AdvertiseError:
await rdv.advertise("A", 30.seconds)
test "Various config error":
expect RendezVousError:
discard RendezVous.new(minDuration = 30.seconds)
expect RendezVousError:
discard RendezVous.new(maxDuration = 73.hours)
expect RendezVousError:
discard RendezVous.new(minDuration = 15.minutes, maxDuration = 10.minutes)

View File

@@ -9,22 +9,11 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
import sequtils, strutils
import chronos
import ../libp2p/[protocols/rendezvous, switch, builders]
import ../libp2p/discovery/[rendezvousinterface, discoverymngr]
import ./helpers
proc createSwitch(rdv: RendezVous = RendezVous.new()): Switch =
SwitchBuilder
.new()
.withRng(newRng())
.withAddresses(@[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()])
.withTcpTransport()
.withMplex()
.withNoise()
.withRendezVous(rdv)
.build()
import ../../libp2p/[protocols/rendezvous, switch, builders]
import ../../libp2p/discovery/[rendezvousinterface, discoverymngr]
import ../helpers
import ./utils
type
MockRendezVous = ref object of RendezVous
@@ -33,7 +22,9 @@ type
MockErrorRendezVous = ref object of MockRendezVous
method advertise*(self: MockRendezVous, namespace: string, ttl: Duration) {.async.} =
method advertise*(
self: MockRendezVous, namespace: string, ttl: Duration
) {.async: (raises: [CancelledError, AdvertiseError]).} =
if namespace == "ns1":
self.numAdvertiseNs1 += 1
elif namespace == "ns2":
@@ -43,9 +34,9 @@ method advertise*(self: MockRendezVous, namespace: string, ttl: Duration) {.asyn
method advertise*(
self: MockErrorRendezVous, namespace: string, ttl: Duration
) {.async.} =
) {.async: (raises: [CancelledError, AdvertiseError]).} =
await procCall MockRendezVous(self).advertise(namespace, ttl)
raise newException(CatchableError, "MockErrorRendezVous.advertise")
raise newException(AdvertiseError, "MockErrorRendezVous.advertise")
suite "RendezVous Interface":
teardown:

83
tests/discovery/utils.nim Normal file
View File

@@ -0,0 +1,83 @@
import sequtils
import chronos
import ../../libp2p/[protobuf/minprotobuf, protocols/rendezvous, switch, builders]
proc createSwitch*(rdv: RendezVous = RendezVous.new()): Switch =
SwitchBuilder
.new()
.withRng(newRng())
.withAddresses(@[MultiAddress.init(MemoryAutoAddress).tryGet()])
.withMemoryTransport()
.withMplex()
.withNoise()
.withRendezVous(rdv)
.build()
proc setupNodes*(count: int): (seq[Switch], seq[RendezVous]) =
doAssert(count > 0, "Count must be greater than 0")
var
nodes: seq[Switch] = @[]
rdvs: seq[RendezVous] = @[]
for x in 0 ..< count:
let rdv = RendezVous.new()
let node = createSwitch(rdv)
nodes.add(node)
rdvs.add(rdv)
return (nodes, rdvs)
proc setupRendezvousNodeWithPeerNodes*(
count: int
): (Switch, seq[Switch], seq[RendezVous], RendezVous) =
let
(nodes, rdvs) = setupNodes(count + 1)
rendezvousNode = nodes[0]
rendezvousRdv = rdvs[0]
peerNodes = nodes[1 ..^ 1]
peerRdvs = rdvs[1 ..^ 1]
return (rendezvousNode, peerNodes, peerRdvs, rendezvousRdv)
template startAndDeferStop*(nodes: seq[Switch]) =
await allFutures(nodes.mapIt(it.start()))
defer:
await allFutures(nodes.mapIt(it.stop()))
proc connectNodes*[T: Switch](dialer: T, target: T) {.async.} =
await dialer.connect(target.peerInfo.peerId, target.peerInfo.addrs)
proc connectNodesToRendezvousNode*[T: Switch](
nodes: seq[T], rendezvousNode: T
) {.async.} =
for node in nodes:
await connectNodes(node, rendezvousNode)
proc buildProtobufCookie*(offset: uint64, namespace: string): seq[byte] =
var pb = initProtoBuffer()
pb.write(1, offset)
pb.write(2, namespace)
pb.finish()
pb.buffer
proc injectCookieForPeer*(
rdv: RendezVous, peerId: PeerId, namespace: string, cookie: seq[byte]
) =
discard rdv.cookiesSaved.hasKeyOrPut(peerId, {namespace: cookie}.toTable())
proc populatePeerRegistrations*(
peerRdv: RendezVous, targetRdv: RendezVous, namespace: string, count: int
) {.async.} =
# Test helper: quickly populate many registrations for a peer.
# We first create a single real registration, then clone that record
# directly into the rendezvous registry to reach the desired count fast.
#
# Notes:
# - Calling advertise() concurrently results in bufferstream defect.
# - Calling advertise() sequentially is too slow for large counts.
await peerRdv.advertise(namespace)
let record = targetRdv.registered.s[0]
for i in 0 ..< count - 1:
targetRdv.registered.s.add(record)

View File

@@ -4,10 +4,12 @@ import std/enumerate
import chronos
import ../../libp2p/[switch, builders]
import ../../libp2p/protocols/kademlia
import ../../libp2p/protocols/kademlia/kademlia
import ../../libp2p/protocols/kademlia/routingtable
import ../../libp2p/protocols/kademlia/keys
import unittest2
import ../utils/async_tests
import ./utils.nim
import ../helpers
proc createSwitch(): Switch =
@@ -31,14 +33,15 @@ proc countBucketEntries(buckets: seq[Bucket], key: Key): uint32 =
suite "KadDHT - FindNode":
teardown:
checkTrackers()
asyncTest "Simple find peer":
asyncTest "Simple find node":
let swarmSize = 3
var switches: seq[Switch]
var kads: seq[KadDHT]
# every node needs a switch, and an assosciated kad mounted to it
for i in 0 ..< swarmSize:
switches.add(createSwitch())
kads.add(KadDHT.new(switches[i]))
kads.add(KadDHT.new(switches[i], PermissiveValidator(), CandSelector()))
switches[i].mount(kads[i])
# Once the the creation/mounting of switches are done, we can start
@@ -79,24 +82,24 @@ suite "KadDHT - FindNode":
)
await switches.mapIt(it.stop()).allFutures()
asyncTest "Relay find peer":
asyncTest "Relay find node":
let parentSwitch = createSwitch()
let parentKad = KadDHT.new(parentSwitch)
let parentKad = KadDHT.new(parentSwitch, PermissiveValidator(), CandSelector())
parentSwitch.mount(parentKad)
await parentSwitch.start()
let broSwitch = createSwitch()
let broKad = KadDHT.new(broSwitch)
let broKad = KadDHT.new(broSwitch, PermissiveValidator(), CandSelector())
broSwitch.mount(broKad)
await broSwitch.start()
let sisSwitch = createSwitch()
let sisKad = KadDHT.new(sisSwitch)
let sisKad = KadDHT.new(sisSwitch, PermissiveValidator(), CandSelector())
sisSwitch.mount(sisKad)
await sisSwitch.start()
let neiceSwitch = createSwitch()
let neiceKad = KadDHT.new(neiceSwitch)
let neiceKad = KadDHT.new(neiceSwitch, PermissiveValidator(), CandSelector())
neiceSwitch.mount(neiceKad)
await neiceSwitch.start()
@@ -142,3 +145,33 @@ suite "KadDHT - FindNode":
await broSwitch.stop()
await sisSwitch.stop()
await neiceSwitch.stop()
asyncTest "Find peer":
let aliceSwitch = createSwitch()
let aliceKad = KadDHT.new(aliceSwitch, PermissiveValidator(), CandSelector())
aliceSwitch.mount(aliceKad)
await aliceSwitch.start()
let bobSwitch = createSwitch()
let bobKad = KadDHT.new(bobSwitch, PermissiveValidator(), CandSelector())
bobSwitch.mount(bobKad)
await bobSwitch.start()
let charlieSwitch = createSwitch()
let charlieKad = KadDHT.new(charlieSwitch, PermissiveValidator(), CandSelector())
charlieSwitch.mount(charlieKad)
await charlieSwitch.start()
await bobKad.bootstrap(@[aliceSwitch.peerInfo])
await charlieKad.bootstrap(@[aliceSwitch.peerInfo])
let peerInfoRes = await bobKad.findPeer(charlieSwitch.peerInfo.peerId)
doAssert peerInfoRes.isOk
doAssert peerInfoRes.get().peerId == charlieSwitch.peerInfo.peerId
let peerInfoRes2 = await bobKad.findPeer(PeerId.random(newRng()).get())
doAssert peerInfoRes2.isErr
await aliceSwitch.stop()
await bobSwitch.stop()
await charlieSwitch.stop()

View File

@@ -0,0 +1,155 @@
{.used.}
import chronicles
import strformat
# import sequtils
import options
import std/[times]
# import std/enumerate
import chronos
import ../../libp2p/[switch, builders]
import ../../libp2p/protocols/kademlia
import ../../libp2p/protocols/kademlia/kademlia
import ../../libp2p/protocols/kademlia/routingtable
import ../../libp2p/protocols/kademlia/keys
import unittest2
import ../utils/async_tests
import ./utils.nim
import std/tables
import ../helpers
proc createSwitch(): Switch =
SwitchBuilder
.new()
.withRng(newRng())
.withAddresses(@[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()])
.withTcpTransport()
.withMplex()
.withNoise()
.build()
proc countBucketEntries(buckets: seq[Bucket], key: Key): uint32 =
var res: uint32 = 0
for b in buckets:
for ent in b.peers:
if ent.nodeId == key:
res += 1
return res
suite "KadDHT - PutVal":
teardown:
checkTrackers()
asyncTest "Simple put":
let switch1 = createSwitch()
let switch2 = createSwitch()
var kad1 = KadDHT.new(switch1, PermissiveValidator(), CandSelector())
var kad2 = KadDHT.new(switch2, PermissiveValidator(), CandSelector())
switch1.mount(kad1)
switch2.mount(kad2)
await allFutures(switch1.start(), switch2.start())
defer:
await allFutures(switch1.stop(), switch2.stop())
await kad2.bootstrap(@[switch1.peerInfo])
discard await kad1.findNode(kad2.rtable.selfId)
discard await kad2.findNode(kad1.rtable.selfId)
doAssert(len(kad1.dataTable.entries) == 0)
doAssert(len(kad2.dataTable.entries) == 0)
let puttedData = kad1.rtable.selfId.getBytes()
let entryKey = EntryKey.init(puttedData)
let entryVal = EntryValue.init(puttedData)
discard await kad2.putValue(entryKey, entryVal, some(1))
let entered1: EntryValue = kad1.dataTable.entries[entryKey].value
let entered2: EntryValue = kad2.dataTable.entries[entryKey].value
var ents = kad1.dataTable.entries
doAssert(entered1.data == entryVal.data, fmt"table: {ents}, putted: {entryVal}")
doAssert(len(kad1.dataTable.entries) == 1)
ents = kad2.dataTable.entries
doAssert(entered2.data == entryVal.data, fmt"table: {ents}, putted: {entryVal}")
doAssert(len(kad2.dataTable.entries) == 1)
asyncTest "Change Validator":
let switch1 = createSwitch()
let switch2 = createSwitch()
var kad1 = KadDHT.new(switch1, RestrictiveValidator(), CandSelector())
var kad2 = KadDHT.new(switch2, RestrictiveValidator(), CandSelector())
switch1.mount(kad1)
switch2.mount(kad2)
await allFutures(switch1.start(), switch2.start())
defer:
await allFutures(switch1.stop(), switch2.stop())
await kad2.bootstrap(@[switch1.peerInfo])
doAssert(len(kad1.dataTable.entries) == 0)
let puttedData = kad1.rtable.selfId.getBytes()
let entryVal = EntryValue.init(puttedData)
let entryKey = EntryKey.init(puttedData)
discard await kad2.putValue(entryKey, entryVal, some(1))
doAssert(len(kad1.dataTable.entries) == 0, fmt"content: {kad1.dataTable.entries}")
kad1.setValidator(PermissiveValidator())
discard await kad2.putValue(entryKey, entryVal, some(1))
doAssert(len(kad1.dataTable.entries) == 0, fmt"{kad1.dataTable.entries}")
kad2.setValidator(PermissiveValidator())
discard await kad2.putValue(entryKey, entryVal, some(1))
doAssert(len(kad1.dataTable.entries) == 1, fmt"{kad1.dataTable.entries}")
asyncTest "Good Time":
let switch1 = createSwitch()
let switch2 = createSwitch()
var kad1 = KadDHT.new(switch1, PermissiveValidator(), CandSelector())
var kad2 = KadDHT.new(switch2, PermissiveValidator(), CandSelector())
switch1.mount(kad1)
switch2.mount(kad2)
await allFutures(switch1.start(), switch2.start())
defer:
await allFutures(switch1.stop(), switch2.stop())
await kad2.bootstrap(@[switch1.peerInfo])
let puttedData = kad1.rtable.selfId.getBytes()
let entryVal = EntryValue.init(puttedData)
let entryKey = EntryKey.init(puttedData)
discard await kad2.putValue(entryKey, entryVal, some(1))
let time: string = kad1.dataTable.entries[entryKey].time.ts
let now = times.now().utc
let parsed = time.parse(initTimeFormat("yyyy-MM-dd'T'HH:mm:ss'Z'"), utc())
# get the diff between the stringified-parsed and the direct "now"
let elapsed = (now - parsed)
doAssert(elapsed < times.initDuration(seconds = 2))
asyncTest "Reselect":
let switch1 = createSwitch()
let switch2 = createSwitch()
var kad1 = KadDHT.new(switch1, PermissiveValidator(), OthersSelector())
var kad2 = KadDHT.new(switch2, PermissiveValidator(), OthersSelector())
switch1.mount(kad1)
switch2.mount(kad2)
await allFutures(switch1.start(), switch2.start())
defer:
await allFutures(switch1.stop(), switch2.stop())
await kad2.bootstrap(@[switch1.peerInfo])
let puttedData = kad1.rtable.selfId.getBytes()
let entryVal = EntryValue.init(puttedData)
let entryKey = EntryKey.init(puttedData)
discard await kad1.putValue(entryKey, entryVal, some(1))
doAssert(len(kad2.dataTable.entries) == 1, fmt"{kad1.dataTable.entries}")
doAssert(kad2.dataTable.entries[entryKey].value.data == entryVal.data)
discard await kad1.putValue(entryKey, EntryValue.init(@[]), some(1))
doAssert(kad2.dataTable.entries[entryKey].value.data == entryVal.data)
kad2.setSelector(CandSelector())
kad1.setSelector(CandSelector())
discard await kad1.putValue(entryKey, EntryValue.init(@[]), some(1))
doAssert(
kad2.dataTable.entries[entryKey].value == EntryValue.init(@[]),
fmt"{kad2.dataTable.entries}",
)

View File

@@ -12,23 +12,24 @@
import unittest
import chronos
import ../../libp2p/crypto/crypto
import ../../libp2p/protocols/kademlia/[routingtable, consts, keys]
import ../../libp2p/protocols/kademlia/[xordistance, routingtable, consts, keys]
import results
proc testKey*(x: byte): Key =
var buf: array[IdLength, byte]
buf[31] = x
return Key(kind: KeyType.Unhashed, data: buf)
return Key(kind: KeyType.Raw, data: @buf)
let rng = crypto.newRng()
suite "routing table":
test "inserts single key in correct bucket":
let selfId = testKey(0)
var rt = RoutingTable.init(selfId)
var rt = RoutingTable.init(selfId, Opt.none(XorDHasher))
let other = testKey(0b10000000)
discard rt.insert(other)
let idx = bucketIndex(selfId, other)
let idx = bucketIndex(selfId, other, Opt.none(XorDHasher))
check:
rt.buckets.len > idx
rt.buckets[idx].peers.len == 1
@@ -36,12 +37,11 @@ suite "routing table":
test "does not insert beyond capacity":
let selfId = testKey(0)
var rt = RoutingTable.init(selfId)
var rt = RoutingTable.init(selfId, Opt.some(noOpHasher))
let targetBucket = 6
for _ in 0 ..< DefaultReplic + 5:
var kid = randomKeyInBucketRange(selfId, targetBucket, rng)
kid.kind = KeyType.Unhashed
# Overriding so we don't use sha for comparing xor distances
kid.kind = KeyType.Raw # Overriding so we don't use sha for comparing xor distances
discard rt.insert(kid)
check targetBucket < rt.buckets.len
@@ -50,7 +50,7 @@ suite "routing table":
test "findClosest returns sorted keys":
let selfId = testKey(0)
var rt = RoutingTable.init(selfId)
var rt = RoutingTable.init(selfId, Opt.some(noOpHasher))
let ids = @[testKey(1), testKey(2), testKey(3), testKey(4), testKey(5)]
for id in ids:
discard rt.insert(id)
@@ -75,9 +75,8 @@ suite "routing table":
let selfId = testKey(0)
let targetBucket = 3
var rid = randomKeyInBucketRange(selfId, targetBucket, rng)
rid.kind = KeyType.Unhashed
# Overriding so we don't use sha for comparing xor distances
let idx = bucketIndex(selfId, rid)
rid.kind = KeyType.Raw # Overriding so we don't use sha for comparing xor distances
let idx = bucketIndex(selfId, rid, Opt.some(noOpHasher))
check:
idx == targetBucket
rid != selfId

27
tests/kademlia/utils.nim Normal file
View File

@@ -0,0 +1,27 @@
{.used.}
import results
import ../../libp2p/protocols/kademlia/kademlia
type PermissiveValidator* = ref object of EntryValidator
method isValid*(self: PermissiveValidator, key: EntryKey, val: EntryValue): bool =
true
type RestrictiveValidator* = ref object of EntryValidator
method isValid(self: RestrictiveValidator, key: EntryKey, val: EntryValue): bool =
false
type CandSelector* = ref object of EntrySelector
method select*(
self: CandSelector, cand: EntryRecord, others: seq[EntryRecord]
): Result[EntryRecord, string] =
return ok(cand)
type OthersSelector* = ref object of EntrySelector
method select*(
self: OthersSelector, cand: EntryRecord, others: seq[EntryRecord]
): Result[EntryRecord, string] =
return
if others.len == 0:
ok(cand)
else:
ok(others[0])

View File

@@ -243,10 +243,14 @@ suite "GossipSub Integration - Control Messages":
# When an IHAVE message is sent
let p1 = n0.getOrCreatePeer(n1.peerInfo.peerId, @[GossipSubCodec_12])
n0.broadcast(@[p1], RPCMsg(control: some(ihaveMessage)), isHighPriority = false)
await waitForHeartbeat()
# Then the peer has the message ID
# Wait until IHAVE response is received
checkUntilTimeout:
receivedIHaves[].len == 1
# Then the peer has exactly one IHAVE message with the correct message ID
check:
receivedIHaves[].len == 1
receivedIHaves[0] == ControlIHave(topicID: topic, messageIDs: @[messageID])
asyncTest "IWANT messages correctly request messages by their IDs":
@@ -281,10 +285,14 @@ suite "GossipSub Integration - Control Messages":
# When an IWANT message is sent
let p1 = n0.getOrCreatePeer(n1.peerInfo.peerId, @[GossipSubCodec_12])
n0.broadcast(@[p1], RPCMsg(control: some(iwantMessage)), isHighPriority = false)
await waitForHeartbeat()
# Then the peer has the message ID
# Wait until IWANT response is received
checkUntilTimeout:
receivedIWants[].len == 1
# Then the peer has exactly one IWANT message with the correct message ID
check:
receivedIWants[].len == 1
receivedIWants[0] == ControlIWant(messageIDs: @[messageID])
asyncTest "IHAVE for message not held by peer triggers IWANT response to sender":
@@ -316,10 +324,14 @@ suite "GossipSub Integration - Control Messages":
# When an IHAVE message is sent from node0
let p1 = n0.getOrCreatePeer(n1.peerInfo.peerId, @[GossipSubCodec_12])
n0.broadcast(@[p1], RPCMsg(control: some(ihaveMessage)), isHighPriority = false)
await waitForHeartbeat()
# Then node0 should receive an IWANT message from node1 (as node1 doesn't have the message)
# Wait until IWANT response is received
checkUntilTimeout:
receivedIWants[].len == 1
# Then node0 should receive exactly one IWANT message from node1
check:
receivedIWants[].len == 1
receivedIWants[0] == ControlIWant(messageIDs: @[messageID])
asyncTest "IDONTWANT":

View File

@@ -334,14 +334,6 @@ suite "GossipSub Integration - Mesh Management":
# When DValues of Node0 are updated back to the initial dValues
node0.parameters.applyDValues(dValues)
# Waiting more than one heartbeat (60ms) and less than pruneBackoff (1s)
await sleepAsync(pruneBackoff.div(2))
check:
node0.mesh.getOrDefault(topic).len == newDValues.get.dHigh.get
# When pruneBackoff period is done
await sleepAsync(pruneBackoff)
# Then on the next heartbeat mesh is rebalanced and peers are regrafted to the initial d value
check:
checkUntilTimeout:
node0.mesh.getOrDefault(topic).len == dValues.get.d.get

View File

@@ -76,33 +76,6 @@ suite "GossipSub Integration - Message Handling":
teardown:
checkTrackers()
asyncTest "Split IWANT replies when individual messages are below maxSize but combined exceed maxSize":
# This test checks if two messages, each below the maxSize, are correctly split when their combined size exceeds maxSize.
# Expected: Both messages should be received.
let (gossip0, gossip1, receivedMessages) = await setupTest()
let messageSize = gossip1.maxMessageSize div 2 + 1
let (iwantMessageIds, sentMessages) =
createMessages(gossip0, gossip1, messageSize, messageSize)
gossip1.broadcast(
gossip1.mesh["foobar"],
RPCMsg(
control: some(
ControlMessage(
ihave: @[ControlIHave(topicID: "foobar", messageIDs: iwantMessageIds)]
)
)
),
isHighPriority = false,
)
checkUntilTimeout:
receivedMessages[] == sentMessages
check receivedMessages[].len == 2
await teardownTest(gossip0, gossip1)
asyncTest "Discard IWANT replies when both messages individually exceed maxSize":
# This test checks if two messages, each exceeding the maxSize, are discarded and not sent.
# Expected: No messages should be received.
@@ -157,41 +130,6 @@ suite "GossipSub Integration - Message Handling":
await teardownTest(gossip0, gossip1)
asyncTest "Split IWANT replies when one message is below maxSize and the other exceeds maxSize":
# This test checks if, when given two messages where one is below maxSize and the other exceeds it, only the smaller message is processed and sent.
# Expected: Only the smaller message should be received.
let (gossip0, gossip1, receivedMessages) = await setupTest()
let maxSize = gossip1.maxMessageSize
let size1 = maxSize div 2
let size2 = maxSize + 10
let (bigIWantMessageIds, sentMessages) =
createMessages(gossip0, gossip1, size1, size2)
gossip1.broadcast(
gossip1.mesh["foobar"],
RPCMsg(
control: some(
ControlMessage(
ihave: @[ControlIHave(topicID: "foobar", messageIDs: bigIWantMessageIds)]
)
)
),
isHighPriority = false,
)
var smallestSet: HashSet[seq[byte]]
let seqs = toSeq(sentMessages)
if seqs[0] < seqs[1]:
smallestSet.incl(seqs[0])
else:
smallestSet.incl(seqs[1])
checkUntilTimeout:
receivedMessages[] == smallestSet
check receivedMessages[].len == 1
await teardownTest(gossip0, gossip1)
asyncTest "messages are not sent back to source or forwarding peer":
let
numberOfNodes = 3

View File

@@ -27,6 +27,28 @@ suite "Message":
check verify(msg)
test "signature with missing key":
let
seqno = 11'u64
seckey = PrivateKey.random(Ed25519, rng[]).get()
pubkey = seckey.getPublicKey().get()
peer = PeerInfo.new(seckey)
check peer.peerId.hasPublicKey() == true
var msg = Message.init(some(peer), @[], "topic", some(seqno), sign = true)
msg.key = @[]
# get the key from fromPeer field (inlined)
check verify(msg)
test "signature without inlined pubkey in peerId":
let
seqno = 11'u64
peer = PeerInfo.new(PrivateKey.random(RSA, rng[]).get())
var msg = Message.init(some(peer), @[], "topic", some(seqno), sign = true)
msg.key = @[]
# shouldn't work since there's no key field
# and the key is not inlined in peerid (too large)
check verify(msg) == false
test "defaultMsgIdProvider success":
let
seqno = 11'u64

View File

@@ -14,7 +14,7 @@
import chronos
import
../../libp2p/[protocols/connectivity/autonat/client, peerid, multiaddress, switch]
from ../../libp2p/protocols/connectivity/autonat/core import
from ../../libp2p/protocols/connectivity/autonat/types import
NetworkReachability, AutonatUnreachableError, AutonatError
type

160
tests/testautonatv2.nim Normal file
View File

@@ -0,0 +1,160 @@
{.used.}
# Nim-Libp2p
# Copyright (c) 2025 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import std/options
import chronos
import
../libp2p/[
switch,
transports/tcptransport,
upgrademngrs/upgrade,
builders,
protocols/connectivity/autonatv2/types,
protocols/connectivity/autonatv2/utils,
],
./helpers
proc checkEncodeDecode[T](msg: T) =
# this would be equivalent of doing the following (e.g. for DialBack)
# check msg == DialBack.decode(msg.encode()).get()
check msg == T.decode(msg.encode()).get()
proc newAutonatV2ServerSwitch(): Switch =
var builder = newStandardSwitchBuilder().withAutonatV2()
return builder.build()
suite "AutonatV2":
teardown:
checkTrackers()
asyncTest "encode/decode messages":
# DialRequest
checkEncodeDecode(
DialRequest(
addrs:
@[
MultiAddress.init("/ip4/127.0.0.1/tcp/4040").get(),
MultiAddress.init("/ip4/127.0.0.1/tcp/4041").get(),
],
nonce: 42,
)
)
# DialResponse
checkEncodeDecode(
DialResponse(
status: ResponseStatus.Ok,
addrIdx: Opt.some(1.uint32),
dialStatus: Opt.some(DialStatus.Ok),
)
)
# DialDataRequest
checkEncodeDecode(DialDataRequest(addrIdx: 42, numBytes: 128))
# DialDataResponse
checkEncodeDecode(DialDataResponse(data: @[1'u8, 2, 3, 4, 5]))
# AutonatV2Msg - DialRequest
checkEncodeDecode(
AutonatV2Msg(
msgType: MsgType.DialRequest,
dialReq: DialRequest(
addrs:
@[
MultiAddress.init("/ip4/127.0.0.1/tcp/4040").get(),
MultiAddress.init("/ip4/127.0.0.1/tcp/4041").get(),
],
nonce: 42,
),
)
)
# AutonatV2Msg - DialResponse
checkEncodeDecode(
AutonatV2Msg(
msgType: MsgType.DialResponse,
dialResp: DialResponse(
status: ResponseStatus.Ok,
addrIdx: Opt.some(1.uint32),
dialStatus: Opt.some(DialStatus.Ok),
),
)
)
# AutonatV2Msg - DialDataRequest
checkEncodeDecode(
AutonatV2Msg(
msgType: MsgType.DialDataRequest,
dialDataReq: DialDataRequest(addrIdx: 42, numBytes: 128),
)
)
# AutonatV2Msg - DialDataResponse
checkEncodeDecode(
AutonatV2Msg(
msgType: MsgType.DialDataResponse,
dialDataResp: DialDataResponse(data: @[1'u8, 2, 3, 4, 5]),
)
)
# DialBack
checkEncodeDecode(DialBack(nonce: 123456))
# DialBackResponse
checkEncodeDecode(DialBackResponse(status: DialBackStatus.Ok))
asyncTest "asNetworkReachability":
check asNetworkReachability(DialResponse(status: EInternalError)) == Unknown
check asNetworkReachability(DialResponse(status: ERequestRejected)) == Unknown
check asNetworkReachability(DialResponse(status: EDialRefused)) == Unknown
check asNetworkReachability(
DialResponse(status: ResponseStatus.Ok, dialStatus: Opt.none(DialStatus))
) == Unknown
check asNetworkReachability(
DialResponse(status: ResponseStatus.Ok, dialStatus: Opt.some(Unused))
) == Unknown
check asNetworkReachability(
DialResponse(status: ResponseStatus.Ok, dialStatus: Opt.some(EDialError))
) == NotReachable
check asNetworkReachability(
DialResponse(status: ResponseStatus.Ok, dialStatus: Opt.some(EDialBackError))
) == NotReachable
check asNetworkReachability(
DialResponse(status: ResponseStatus.Ok, dialStatus: Opt.some(DialStatus.Ok))
) == Reachable
asyncTest "asAutonatV2Response":
let addrs = @[MultiAddress.init("/ip4/127.0.0.1/tcp/4000").get()]
let errorDialResp = DialResponse(
status: ResponseStatus.Ok,
addrIdx: Opt.none(AddrIdx),
dialStatus: Opt.none(DialStatus),
)
check asAutonatV2Response(errorDialResp, addrs) ==
AutonatV2Response(
reachability: Unknown, dialResp: errorDialResp, addrs: Opt.none(MultiAddress)
)
let correctDialResp = DialResponse(
status: ResponseStatus.Ok,
addrIdx: Opt.some(0.AddrIdx),
dialStatus: Opt.some(DialStatus.Ok),
)
check asAutonatV2Response(correctDialResp, addrs) ==
AutonatV2Response(
reachability: Reachable, dialResp: correctDialResp, addrs: Opt.some(addrs[0])
)
asyncTest "Instanciate server":
let serverSwitch = newAutonatV2ServerSwitch()
await serverSwitch.start()
await serverSwitch.stop()

View File

@@ -17,7 +17,7 @@ import ../libp2p/[connmanager, stream/connection, crypto/crypto, muxers/muxer, p
import helpers
proc getMuxer(peerId: PeerId, dir: Direction = Direction.In): Muxer =
return Muxer(connection: Connection.new(peerId, dir, Opt.none(MultiAddress)))
return Muxer(connection: Connection.new(peerId, dir))
type TestMuxer = ref object of Muxer
peerId: PeerId
@@ -25,7 +25,7 @@ type TestMuxer = ref object of Muxer
method newStream*(
m: TestMuxer, name: string = "", lazy: bool = false
): Future[Connection] {.async: (raises: [CancelledError, LPStreamError, MuxerError]).} =
Connection.new(m.peerId, Direction.Out, Opt.none(MultiAddress))
Connection.new(m.peerId, Direction.Out)
suite "Connection Manager":
teardown:
@@ -124,7 +124,7 @@ suite "Connection Manager":
let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet()
let muxer = new TestMuxer
let connection = Connection.new(peerId, Direction.In, Opt.none(MultiAddress))
let connection = Connection.new(peerId, Direction.In)
muxer.peerId = peerId
muxer.connection = connection
@@ -144,7 +144,7 @@ suite "Connection Manager":
let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet()
let muxer = new TestMuxer
let connection = Connection.new(peerId, Direction.In, Opt.none(MultiAddress))
let connection = Connection.new(peerId, Direction.In)
muxer.peerId = peerId
muxer.connection = connection

View File

@@ -14,7 +14,7 @@ import unittest2
import ../libp2p/protocols/connectivity/dcutr/core as dcore
import ../libp2p/protocols/connectivity/dcutr/[client, server]
from ../libp2p/protocols/connectivity/autonat/core import NetworkReachability
from ../libp2p/protocols/connectivity/autonat/types import NetworkReachability
import ../libp2p/builders
import ../libp2p/utils/future
import ./helpers

55
tests/testipaddr.nim Normal file
View File

@@ -0,0 +1,55 @@
{.used.}
# Nim-Libp2p
# Copyright (c) 2025 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import std/options
import chronos
import ../libp2p/[utils/ipaddr], ./helpers
suite "IpAddr Utils":
teardown:
checkTrackers()
test "ipAddrMatches":
# same ip address
check ipAddrMatches(
MultiAddress.init("/ip4/127.0.0.1/tcp/4041").get(),
@[MultiAddress.init("/ip4/127.0.0.1/tcp/4040").get()],
)
# different ip address
check not ipAddrMatches(
MultiAddress.init("/ip4/127.0.0.2/tcp/4041").get(),
@[MultiAddress.init("/ip4/127.0.0.1/tcp/4040").get()],
)
test "ipSupport":
check ipSupport(@[MultiAddress.init("/ip4/127.0.0.1/tcp/4040").get()]) ==
(true, false)
check ipSupport(@[MultiAddress.init("/ip6/::1/tcp/4040").get()]) == (false, true)
check ipSupport(
@[
MultiAddress.init("/ip6/::1/tcp/4040").get(),
MultiAddress.init("/ip4/127.0.0.1/tcp/4040").get(),
]
) == (true, true)
test "isPrivate, isPublic":
check isPrivate("192.168.1.100")
check not isPublic("192.168.1.100")
check isPrivate("10.0.0.25")
check not isPublic("10.0.0.25")
check isPrivate("169.254.12.34")
check not isPublic("169.254.12.34")
check isPrivate("172.31.200.8")
check not isPublic("172.31.200.8")
check not isPrivate("1.1.1.1")
check isPublic("1.1.1.1")
check not isPrivate("185.199.108.153")
check isPublic("185.199.108.153")

View File

@@ -341,6 +341,20 @@ suite "MultiAddress test suite":
MultiAddress.init("/ip4/0.0.0.0").get().protoAddress().get() == address_v4
MultiAddress.init("/ip6/::0").get().protoAddress().get() == address_v6
test "MultiAddress getPart":
let ma = MultiAddress
.init(
"/ip4/0.0.0.0/tcp/0/p2p/QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSupNKC/p2p-circuit/p2p/QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSuNEXT/unix/stdio/"
)
.get()
check:
$ma.getPart(multiCodec("ip4")).get() == "/ip4/0.0.0.0"
$ma.getPart(multiCodec("tcp")).get() == "/tcp/0"
# returns first codec match
$ma.getPart(multiCodec("p2p")).get() ==
"/p2p/QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSupNKC"
ma.getPart(multiCodec("udp")).isErr()
test "MultiAddress getParts":
let ma = MultiAddress
.init(
@@ -421,3 +435,22 @@ suite "MultiAddress test suite":
for item in CrashesVectors:
let res = MultiAddress.init(hexToSeqByte(item))
check res.isErr()
test "areAddrsConsistent":
# same address should be consistent
check areAddrsConsistent(
MultiAddress.init("/ip4/127.0.0.1/tcp/4040").get(),
MultiAddress.init("/ip4/127.0.0.1/tcp/4040").get(),
)
# different addresses with same stack should be consistent
check areAddrsConsistent(
MultiAddress.init("/ip4/127.0.0.2/tcp/4041").get(),
MultiAddress.init("/ip4/127.0.0.1/tcp/4040").get(),
)
# addresses with different stacks should not be consistent
check not areAddrsConsistent(
MultiAddress.init("/ip4/127.0.0.1/tcp/4040").get(),
MultiAddress.init("/ip4/127.0.0.1/udp/4040").get(),
)

View File

@@ -30,11 +30,14 @@ import
import
testnameresolve, testmultistream, testbufferstream, testidentify,
testobservedaddrmanager, testconnmngr, testswitch, testnoise, testpeerinfo,
testpeerstore, testping, testmplex, testrelayv1, testrelayv2, testrendezvous,
testdiscovery, testyamux, testautonat, testautonatservice, testautorelay, testdcutr,
testhpservice, testutility, testhelpers, testwildcardresolverservice, testperf
testpeerstore, testping, testmplex, testrelayv1, testrelayv2, testyamux,
testyamuxheader, testautonat, testautonatservice, testautonatv2, testautorelay,
testdcutr, testhpservice, testutility, testhelpers, testwildcardresolverservice,
testperf
import kademlia/[testencoding, testroutingtable, testfindnode]
import discovery/testdiscovery
import kademlia/[testencoding, testroutingtable, testfindnode, testputval]
when defined(libp2p_autotls_support):
import testautotls

View File

@@ -15,7 +15,7 @@ import
import ./helpers
proc createServerAcceptConn(
server: QuicTransport
server: QuicTransport, isEofExpected: bool = false
): proc(): Future[void] {.
async: (raises: [transport.TransportError, LPStreamError, CancelledError])
.} =
@@ -32,12 +32,19 @@ proc createServerAcceptConn(
continue
let stream = await getStream(QuicSession(conn), Direction.In)
var resp: array[6, byte]
await stream.readExactly(addr resp, 6)
check string.fromBytes(resp) == "client"
defer:
await stream.close()
await stream.write("server")
await stream.close()
try:
var resp: array[6, byte]
await stream.readExactly(addr resp, 6)
check string.fromBytes(resp) == "client"
await stream.write("server")
except LPStreamEOFError as exc:
if isEofExpected:
discard
else:
raise exc
return handler
@@ -139,7 +146,9 @@ suite "Quic transport":
asyncTest "closing session should close all streams":
let server = await createTransport()
asyncSpawn createServerAcceptConn(server)()
# because some clients will not write full message,
# it is expected for server to receive eof
asyncSpawn createServerAcceptConn(server, true)()
defer:
await server.stop()
@@ -165,3 +174,36 @@ suite "Quic transport":
# run multiple clients simultainiously
await allFutures(runClient(), runClient(), runClient())
asyncTest "read/write Lp":
proc serverHandler(
server: QuicTransport
) {.async: (raises: [transport.TransportError, LPStreamError, CancelledError]).} =
while true:
let conn =
try:
await server.accept()
except QuicTransportAcceptStopped:
return # Transport is stopped
if conn == nil:
continue
let stream = await getStream(QuicSession(conn), Direction.In)
check (await stream.readLp(100)) == fromHex("1234")
await stream.writeLp(fromHex("5678"))
await stream.close()
proc runClient(server: QuicTransport) {.async.} =
let client = await createTransport()
let conn = await client.dial("", server.addrs[0])
let stream = await getStream(QuicSession(conn), Direction.Out)
await stream.writeLp(fromHex("1234"))
check (await stream.readLp(100)) == fromHex("5678")
await client.stop()
let server = await createTransport()
asyncSpawn serverHandler(server)
defer:
await server.stop()
await runClient(server)

View File

@@ -1,158 +0,0 @@
{.used.}
# Nim-Libp2p
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import sequtils, strutils
import chronos
import ../libp2p/[protocols/rendezvous, switch, builders]
import ../libp2p/discovery/discoverymngr
import ./helpers
proc createSwitch(rdv: RendezVous = RendezVous.new()): Switch =
SwitchBuilder
.new()
.withRng(newRng())
.withAddresses(@[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()])
.withTcpTransport()
.withMplex()
.withNoise()
.withRendezVous(rdv)
.build()
suite "RendezVous":
teardown:
checkTrackers()
asyncTest "Simple local test":
let
rdv = RendezVous.new()
s = createSwitch(rdv)
await s.start()
let res0 = rdv.requestLocally("empty")
check res0.len == 0
await rdv.advertise("foo")
let res1 = rdv.requestLocally("foo")
check:
res1.len == 1
res1[0] == s.peerInfo.signedPeerRecord.data
let res2 = rdv.requestLocally("bar")
check res2.len == 0
rdv.unsubscribeLocally("foo")
let res3 = rdv.requestLocally("foo")
check res3.len == 0
await s.stop()
asyncTest "Simple remote test":
let
rdv = RendezVous.new()
client = createSwitch(rdv)
remoteSwitch = createSwitch()
await client.start()
await remoteSwitch.start()
await client.connect(remoteSwitch.peerInfo.peerId, remoteSwitch.peerInfo.addrs)
let res0 = await rdv.request(Opt.some("empty"))
check res0.len == 0
await rdv.advertise("foo")
let res1 = await rdv.request(Opt.some("foo"))
check:
res1.len == 1
res1[0] == client.peerInfo.signedPeerRecord.data
let res2 = await rdv.request(Opt.some("bar"))
check res2.len == 0
await rdv.unsubscribe("foo")
let res3 = await rdv.request(Opt.some("foo"))
check res3.len == 0
await allFutures(client.stop(), remoteSwitch.stop())
asyncTest "Harder remote test":
var
rdvSeq: seq[RendezVous] = @[]
clientSeq: seq[Switch] = @[]
remoteSwitch = createSwitch()
for x in 0 .. 10:
rdvSeq.add(RendezVous.new())
clientSeq.add(createSwitch(rdvSeq[^1]))
await remoteSwitch.start()
await allFutures(clientSeq.mapIt(it.start()))
await allFutures(
clientSeq.mapIt(remoteSwitch.connect(it.peerInfo.peerId, it.peerInfo.addrs))
)
await allFutures(rdvSeq.mapIt(it.advertise("foo")))
var data = clientSeq.mapIt(it.peerInfo.signedPeerRecord.data)
let res1 = await rdvSeq[0].request(Opt.some("foo"), 5)
check res1.len == 5
for d in res1:
check d in data
data.keepItIf(it notin res1)
let res2 = await rdvSeq[0].request(Opt.some("foo"))
check res2.len == 5
for d in res2:
check d in data
let res3 = await rdvSeq[0].request(Opt.some("foo"))
check res3.len == 0
let res4 = await rdvSeq[0].request()
check res4.len == 11
let res5 = await rdvSeq[0].request(Opt.none(string))
check res5.len == 11
await remoteSwitch.stop()
await allFutures(clientSeq.mapIt(it.stop()))
asyncTest "Simple cookie test":
let
rdvA = RendezVous.new()
rdvB = RendezVous.new()
clientA = createSwitch(rdvA)
clientB = createSwitch(rdvB)
remoteSwitch = createSwitch()
await clientA.start()
await clientB.start()
await remoteSwitch.start()
await clientA.connect(remoteSwitch.peerInfo.peerId, remoteSwitch.peerInfo.addrs)
await clientB.connect(remoteSwitch.peerInfo.peerId, remoteSwitch.peerInfo.addrs)
await rdvA.advertise("foo")
let res1 = await rdvA.request(Opt.some("foo"))
await rdvB.advertise("foo")
let res2 = await rdvA.request(Opt.some("foo"))
check:
res2.len == 1
res2[0] == clientB.peerInfo.signedPeerRecord.data
await allFutures(clientA.stop(), clientB.stop(), remoteSwitch.stop())
asyncTest "Various local error":
let
rdv = RendezVous.new(minDuration = 1.minutes, maxDuration = 72.hours)
switch = createSwitch(rdv)
expect AdvertiseError:
discard await rdv.request(Opt.some("A".repeat(300)))
expect AdvertiseError:
discard await rdv.request(Opt.some("A"), -1)
expect AdvertiseError:
discard await rdv.request(Opt.some("A"), 3000)
expect AdvertiseError:
await rdv.advertise("A".repeat(300))
expect AdvertiseError:
await rdv.advertise("A", 73.hours)
expect AdvertiseError:
await rdv.advertise("A", 30.seconds)
test "Various config error":
expect RendezVousError:
discard RendezVous.new(minDuration = 30.seconds)
expect RendezVousError:
discard RendezVous.new(maxDuration = 73.hours)
expect RendezVousError:
discard RendezVous.new(minDuration = 15.minutes, maxDuration = 10.minutes)

View File

@@ -11,7 +11,11 @@
import sugar
import chronos
import ../libp2p/[stream/connection, stream/bridgestream, muxers/yamux/yamux], ./helpers
import ../libp2p/[stream/connection, stream/bridgestream, muxers/yamux/yamux]
import ./helpers
import ./utils/futures
include ../libp2p/muxers/yamux/yamux
proc newBlockerFut(): Future[void] {.async: (raises: [], raw: true).} =
newFuture[void]()
@@ -24,6 +28,8 @@ suite "Yamux":
ws: int = YamuxDefaultWindowSize,
inTo: Duration = 5.minutes,
outTo: Duration = 5.minutes,
startHandlera = true,
startHandlerb = true,
) {.inject.} =
#TODO in a template to avoid threadvar
let
@@ -34,7 +40,14 @@ suite "Yamux":
Yamux.new(conna, windowSize = ws, inTimeout = inTo, outTimeout = outTo)
yamuxb {.inject.} =
Yamux.new(connb, windowSize = ws, inTimeout = inTo, outTimeout = outTo)
(handlera, handlerb) = (yamuxa.handle(), yamuxb.handle())
var
handlera = completedFuture()
handlerb = completedFuture()
if startHandlera:
handlera = yamuxa.handle()
if startHandlerb:
handlerb = yamuxb.handle()
defer:
await allFutures(
@@ -166,8 +179,9 @@ suite "Yamux":
let writerBlocker = newBlockerFut()
var numberOfRead = 0
const newWindow = 20
yamuxb.streamHandler = proc(conn: Connection) {.async: (raises: []).} =
YamuxChannel(conn).setMaxRecvWindow(20)
YamuxChannel(conn).setMaxRecvWindow(newWindow)
try:
var buffer: array[256000, byte]
while (await conn.readOnce(addr buffer[0], 256000)) > 0:
@@ -183,13 +197,14 @@ suite "Yamux":
# Need to exhaust initial window first
await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block
await streamA.write(newSeq[byte](142))
const extraBytes = 160
await streamA.write(newSeq[byte](extraBytes))
await streamA.close()
await writerBlocker
# 1 for initial exhaustion + (142 / 20) = 9
check numberOfRead == 9
# 1 for initial exhaustion + (160 / 20) = 9
check numberOfRead == 1 + (extraBytes / newWindow).int
asyncTest "Saturate until reset":
mSetup()
@@ -412,3 +427,101 @@ suite "Yamux":
await streamA.writeLp(fromHex("1234"))
await streamA.close()
check (await streamA.readLp(100)) == fromHex("5678")
suite "Frame handling and stream initiation":
asyncTest "Ping Syn responds Ping Ack":
mSetup(startHandlera = false)
let payload: uint32 = 0x12345678'u32
await conna.write(YamuxHeader.ping(MsgFlags.Syn, payload))
let header = await conna.readHeader()
check:
header.msgType == Ping
header.flags == {Ack}
header.length == payload
asyncTest "Go Away Status responds with Go Away":
mSetup(startHandlera = false)
await conna.write(YamuxHeader.goAway(GoAwayStatus.ProtocolError))
let header = await conna.readHeader()
check:
header.msgType == GoAway
header.flags == {}
header.length == GoAwayStatus.NormalTermination.uint32
for testCase in [
YamuxHeader.data(streamId = 1'u32, length = 0, {Syn}),
YamuxHeader.windowUpdate(streamId = 5'u32, delta = 0, {Syn}),
]:
asyncTest "Syn opens stream and sends Ack - " & $testCase:
mSetup(startHandlera = false)
yamuxb.streamHandler = proc(conn: Connection) {.async: (raises: []).} =
try:
await conn.close()
except CancelledError, LPStreamError:
return
await conna.write(testCase)
let ackHeader = await conna.readHeader()
check:
ackHeader.msgType == WindowUpdate
ackHeader.streamId == testCase.streamId
ackHeader.flags == {Ack}
check:
yamuxb.channels.hasKey(testCase.streamId)
yamuxb.channels[testCase.streamId].opened
let finHeader = await conna.readHeader()
check:
finHeader.msgType == Data
finHeader.streamId == testCase.streamId
finHeader.flags == {Fin}
for badHeader in [
# Reserved parity on Data+Syn (even id against responder)
YamuxHeader.data(streamId = 2'u32, length = 0, {Syn}),
# Reserved stream id 0
YamuxHeader.data(streamId = 0'u32, length = 0, {Syn}),
# Reserved parity on WindowUpdate+Syn (even id against responder)
YamuxHeader.windowUpdate(streamId = 4'u32, delta = 0, {Syn}),
]:
asyncTest "Reject invalid/unknown header - " & $badHeader:
mSetup(startHandlera = false)
await conna.write(badHeader)
let header = await conna.readHeader()
check:
header.msgType == GoAway
header.flags == {}
header.length == GoAwayStatus.ProtocolError.uint32
not yamuxb.channels.hasKey(badHeader.streamId)
asyncTest "Flush unknown-stream Data up to budget then ProtocolError when exceeded":
# Cover the flush path: streamId not in channels, no Syn, with a pre-seeded
# flush budget in yamuxb.flushed. First frame should be flushed (no GoAway),
# second frame exceeding the remaining budget should trigger ProtocolError.
mSetup(startHandlera = false)
let streamId = 11'u32
yamuxb.flushed[streamId] = 4 # allow up to 4 bytes to be flushed
# 1) Send a Data frame (no Syn) with length=3 and a 3-byte payload -> should be flushed.
await conna.write(YamuxHeader.data(streamId = streamId, length = 3))
await conna.write(fromHex("010203"))
# 2) Send another Data frame with length=2 (remaining budget is 1) -> exceeds, expect GoAway.
await conna.write(YamuxHeader.data(streamId = streamId, length = 2))
await conna.write(fromHex("0405"))
let header = await conna.readHeader()
check:
header.msgType == GoAway
header.flags == {}
header.length == GoAwayStatus.ProtocolError.uint32

321
tests/testyamuxheader.nim Normal file
View File

@@ -0,0 +1,321 @@
{.used.}
# Nim-Libp2p
# Copyright (c) 2025 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import ../libp2p/stream/[bufferstream, lpstream]
import ./helpers
include ../libp2p/muxers/yamux/yamux
proc readBytes(bytes: array[12, byte]): Future[YamuxHeader] {.async.} =
let bs = BufferStream.new()
defer:
await bs.close()
await bs.pushData(@bytes)
return await readHeader(bs)
suite "Yamux Header Tests":
teardown:
checkTrackers()
asyncTest "Data header":
const
streamId = 1
length = 100
flags = {Syn}
let header = YamuxHeader.data(streamId = streamId, length = length, flags)
let dataEncoded = header.encode()
# [version == 0, msgType, flags_high, flags_low, 4x streamId_bytes, 4x length_bytes]
const expected = [byte 0, 0, 0, 1, 0, 0, 0, streamId.byte, 0, 0, 0, length.byte]
check:
dataEncoded == expected
let headerDecoded = await readBytes(dataEncoded)
check:
headerDecoded.version == 0
headerDecoded.msgType == MsgType.Data
headerDecoded.flags == flags
headerDecoded.streamId == streamId
headerDecoded.length == length
asyncTest "Window update":
const
streamId = 5
delta = 1000
flags = {Syn}
let windowUpdateHeader =
YamuxHeader.windowUpdate(streamId = streamId, delta = delta, flags)
let windowEncoded = windowUpdateHeader.encode()
# [version == 0, msgType, flags_high, flags_low, 4x streamId_bytes, 4x delta_bytes]
# delta == 1000 == 0x03E8
const expected = [byte 0, 1, 0, 1, 0, 0, 0, streamId.byte, 0, 0, 0x03, 0xE8]
check:
windowEncoded == expected
let windowDecoded = await readBytes(windowEncoded)
check:
windowDecoded.version == 0
windowDecoded.msgType == MsgType.WindowUpdate
windowDecoded.flags == flags
windowDecoded.streamId == streamId
windowDecoded.length == delta
asyncTest "Ping":
let pingHeader = YamuxHeader.ping(MsgFlags.Syn, 0x12345678'u32)
let pingEncoded = pingHeader.encode()
# [version == 0, msgType, flags_high, flags_low, 4x streamId_bytes, 4x value_bytes]
const expected = [byte 0, 2, 0, 1, 0, 0, 0, 0, 0x12, 0x34, 0x56, 0x78]
check:
pingEncoded == expected
let pingDecoded = await readBytes(pingEncoded)
check:
pingDecoded.version == 0
pingDecoded.msgType == MsgType.Ping
pingDecoded.flags == {Syn}
pingDecoded.streamId == 0
pingDecoded.length == 0x12345678'u32
asyncTest "Go away":
let goAwayHeader = YamuxHeader.goAway(GoAwayStatus.ProtocolError)
let goAwayEncoded = goAwayHeader.encode()
# [version == 0, msgType, flags_high, flags_low, 4x streamId_bytes, 4x error_bytes]
const expected = [byte 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
check:
goAwayEncoded == expected
let goAwayDecoded = await readBytes(goAwayEncoded)
check:
goAwayDecoded.version == 0
goAwayDecoded.msgType == MsgType.GoAway
goAwayDecoded.flags == {}
goAwayDecoded.streamId == 0
goAwayDecoded.length == 1'u32
asyncTest "Error codes":
let encodedNormal = YamuxHeader.goAway(GoAwayStatus.NormalTermination).encode()
let encodedProtocol = YamuxHeader.goAway(GoAwayStatus.ProtocolError).encode()
let encodedInternal = YamuxHeader.goAway(GoAwayStatus.InternalError).encode()
check:
encodedNormal[11] == 0
encodedProtocol[11] == 1
encodedInternal[11] == 2
let decodedNormal = await readBytes(encodedNormal)
let decodedProtocol = await readBytes(encodedProtocol)
let decodedInternal = await readBytes(encodedInternal)
check:
decodedNormal.msgType == MsgType.GoAway
decodedNormal.length == 0'u32
decodedProtocol.msgType == MsgType.GoAway
decodedProtocol.length == 1'u32
decodedInternal.msgType == MsgType.GoAway
decodedInternal.length == 2'u32
asyncTest "Flags":
const
streamId = 1
length = 100
let cases: seq[(set[MsgFlags], uint8)] =
@[
({}, 0'u8),
({Syn}, 1'u8),
({Ack}, 2'u8),
({Syn, Ack}, 3'u8),
({Fin}, 4'u8),
({Syn, Fin}, 5'u8),
({Ack, Fin}, 6'u8),
({Syn, Ack, Fin}, 7'u8),
({Rst}, 8'u8),
({Syn, Rst}, 9'u8),
({Ack, Rst}, 10'u8),
({Syn, Ack, Rst}, 11'u8),
({Fin, Rst}, 12'u8),
({Syn, Fin, Rst}, 13'u8),
({Ack, Fin, Rst}, 14'u8),
({Syn, Ack, Fin, Rst}, 15'u8),
]
for (flags, low) in cases:
let header = YamuxHeader.data(streamId = streamId, length = length, flags)
let encoded = header.encode()
check encoded[2 .. 3] == [byte 0, low]
let decoded = await readBytes(encoded)
check decoded.flags == flags
asyncTest "Boundary conditions":
# Test maximum values
const maxFlags = {Syn, Ack, Fin, Rst}
let maxHeader =
YamuxHeader.data(streamId = uint32.high, length = uint32.high, maxFlags)
let maxEncoded = maxHeader.encode()
const maxExpected = [byte 0, 0, 0, 15, 255, 255, 255, 255, 255, 255, 255, 255]
check:
maxEncoded == maxExpected
let maxDecoded = await readBytes(maxEncoded)
check:
maxDecoded.version == 0
maxDecoded.msgType == MsgType.Data
maxDecoded.flags == maxFlags
maxDecoded.streamId == uint32.high
maxDecoded.length == uint32.high
# Test minimum values
let minHeader = YamuxHeader.data(streamId = 0, length = 0, {})
let minEncoded = minHeader.encode()
const minExpected = [byte 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
check:
minEncoded == minExpected
let minDecoded = await readBytes(minEncoded)
check:
minDecoded.version == 0
minDecoded.msgType == MsgType.Data
minDecoded.flags == {}
minDecoded.streamId == 0
minDecoded.length == 0'u32
asyncTest "Incomplete header should raise LPStreamIncompleteError":
let buff = BufferStream.new()
let valid = YamuxHeader.data(streamId = 7, length = 0, {}).encode()
# Supply only first 10 bytes (<12)
let partial: seq[byte] = @valid[0 .. 9]
await buff.pushData(partial)
# Start the read first so close() can propagate EOF to it
let headerFut = readHeader(buff)
await buff.close()
expect LPStreamIncompleteError:
discard await headerFut
asyncTest "Non-zero version byte is preserved":
let valid = YamuxHeader.data(streamId = 1, length = 100, {Syn}).encode()
var mutated = valid
mutated[0] = 1'u8
let decoded = await readBytes(mutated)
check:
decoded.version == 1
asyncTest "Invalid msgType should raise YamuxError":
let valid = YamuxHeader.data(streamId = 1, length = 0, {}).encode()
var mutated = valid
mutated[1] = 0xFF'u8
expect YamuxError:
discard await readBytes(mutated)
asyncTest "Invalid flags should raise YamuxError":
let valid = YamuxHeader.data(streamId = 1, length = 0, {}).encode()
var mutated = valid
# Set flags to 16 which is outside the allowed 0..15 range
mutated[2] = 0'u8
mutated[3] = 16'u8
expect YamuxError:
discard await readBytes(mutated)
asyncTest "Invalid flags (high byte non-zero) should raise YamuxError":
let valid = YamuxHeader.data(streamId = 1, length = 0, {}).encode()
var mutated = valid
# Set high flags byte to 1, which is outside the allowed 0..15 range
mutated[2] = 1'u8
mutated[3] = 0'u8
expect YamuxError:
discard await readBytes(mutated)
asyncTest "Partial push (6+6 bytes) completes without closing":
const
streamId = 9
length = 42
flags = {Syn}
# Prepare a valid header
let header = YamuxHeader.data(streamId = streamId, length = length, flags)
let bytes = header.encode()
let buff = BufferStream.new()
defer:
await buff.close()
# Push first half (6 bytes)
let first: seq[byte] = @bytes[0 .. 5]
await buff.pushData(first)
# Start read and then push the remaining bytes
let headerFut = readHeader(buff)
let second: seq[byte] = @bytes[6 .. 11]
await buff.pushData(second)
let decoded = await headerFut
check:
decoded.version == 0
decoded.msgType == MsgType.Data
decoded.flags == flags
decoded.streamId == streamId
decoded.length == length
asyncTest "Two headers back-to-back decode sequentially":
let h1 = YamuxHeader.data(streamId = 2, length = 10, {Ack})
let h2 = YamuxHeader.ping(MsgFlags.Syn, 0xABCDEF01'u32)
let b1 = h1.encode()
let b2 = h2.encode()
let buff = BufferStream.new()
defer:
await buff.close()
await buff.pushData(@b1 & @b2)
let d1 = await readHeader(buff)
let d2 = await readHeader(buff)
check:
d1.msgType == MsgType.Data
d1.streamId == 2
d1.length == 10
d1.flags == {Ack}
d2.msgType == MsgType.Ping
d2.streamId == 0
d2.length == 0xABCDEF01'u32
d2.flags == {Syn}
asyncTest "StreamId 0x01020304 encodes big-endian":
const streamId = 0x01020304'u32
let header = YamuxHeader.data(streamId = streamId, length = 0, {})
let enc = header.encode()
check enc[4 .. 7] == [byte 1, 2, 3, 4]
let dec = await readBytes(enc)
check dec.streamId == streamId
asyncTest "GoAway unknown status code is preserved":
let valid = YamuxHeader.goAway(GoAwayStatus.NormalTermination).encode()
var mutated = valid
# Set the GoAway code (last byte) to 255, which is not a known GoAwayStatus
mutated[11] = 255'u8
let decoded = await readBytes(mutated)
check:
decoded.msgType == MsgType.GoAway
decoded.length == 255'u32

View File

@@ -59,3 +59,8 @@ proc waitForStates*[T](
): Future[seq[FutureStateWrapper[T]]] {.async.} =
await sleepAsync(timeout)
return futures.mapIt(it.toState())
proc completedFuture*(): Future[void] =
let f = newFuture[void]()
f.complete()
f