Mirror of https://github.com/vacp2p/nim-libp2p.git (synced 2026-01-10 12:58:05 -05:00)

Compare commits: pwhite/mix ... no-splitti (46 commits)
| SHA1 |
|---|
| 8b0565725f |
| f345026900 |
| 4780af2036 |
| 5d6578a06f |
| 871a5d047f |
| 061195195b |
| 8add5aaaab |
| dbf60b74c7 |
| d2eaf07960 |
| 6e5274487e |
| 7ed62461d7 |
| 6059ee8332 |
| 4f7e232a9e |
| 5eaa43b860 |
| 17ed2d88df |
| c7f29ed5db |
| 9865cc39b5 |
| 601f56b786 |
| 25a8ed4d07 |
| 955e28ff70 |
| f952e6d436 |
| bed83880bf |
| 9bd4b7393f |
| 12d1fae404 |
| 17073dc9e0 |
| b1649b3566 |
| ef20f46b47 |
| 9161529c84 |
| 8b70384b6a |
| f25814a890 |
| 3d5ea1fa3c |
| 2114008704 |
| 04796b210b |
| 59faa023aa |
| fdebea4e14 |
| 0c188df806 |
| abee5326dc |
| 71f04d1bb3 |
| 41ae43ae80 |
| 5dbf077d9e |
| b5fc7582ff |
| 7f83ebb198 |
| ceb89986c1 |
| f4ff27ca6b |
| b517b692df |
| 7cfd26035a |
.github/actions/add_comment/action.yml | 11 (vendored)
@@ -1,12 +1,5 @@
name: Add Comment
description: "Add or update comment in the PR"
inputs:
marker:
description: "Text used to find the comment to update"
required: true
markdown_path:
description: "Path to the file containing markdown"
required: true

runs:
using: "composite"
@@ -16,8 +9,8 @@ runs:
with:
script: |
const fs = require('fs');
const marker = "${{ inputs.marker }}";
const body = fs.readFileSync("${{ inputs.markdown_path }}", 'utf8');
const marker = "${{ env.MARKER }}";
const body = fs.readFileSync("${{ env.COMMENT_SUMMARY_PATH }}", 'utf8');
const { data: comments } = await github.rest.issues.listComments({
owner: context.repo.owner,
repo: context.repo.repo,
.github/actions/generate_plots/action.yml | 24 (vendored, new file)
@@ -0,0 +1,24 @@
name: Generate Plots
description: "Set up Python and run script to generate plots with Docker Stats"

runs:
using: "composite"
steps:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.12"

- name: Install Python dependencies
shell: bash
run: |
python -m pip install --upgrade pip
pip install matplotlib

- name: Plot Docker Stats
shell: bash
run: python performance/scripts/plot_docker_stats.py

- name: Plot Latency History
shell: bash
run: python performance/scripts/plot_latency_history.py
.github/actions/process_stats/action.yml | 21 (vendored, new file)
@@ -0,0 +1,21 @@
name: Process Stats
description: "Set up Nim and run scripts to aggregate latency and process raw docker stats"

runs:
using: "composite"
steps:
- name: Set up Nim
uses: jiro4989/setup-nim-action@v2
with:
nim-version: "2.x"
repo-token: ${{ env.GITHUB_TOKEN }}

- name: Aggregate latency stats and prepare markdown for comment and summary
shell: bash
run: |
nim c -r -d:release -o:/tmp/process_latency_stats ./performance/scripts/process_latency_stats.nim

- name: Process raw docker stats to csv files
shell: bash
run: |
nim c -r -d:release -o:/tmp/process_docker_stats ./performance/scripts/process_docker_stats.nim
.github/actions/publish_history/action.yml | 36 (vendored, new file)
@@ -0,0 +1,36 @@
name: Publish Latency History
description: "Publish latency history CSVs in a configurable branch and folder"

runs:
using: "composite"
steps:
- name: Clone the branch
uses: actions/checkout@v4
with:
repository: ${{ github.repository }}
ref: ${{ env.PUBLISH_BRANCH_NAME }}
path: ${{ env.CHECKOUT_SUBFOLDER_HISTORY }}
fetch-depth: 0

- name: Commit & push latency history CSVs
shell: bash
run: |
cd "$CHECKOUT_SUBFOLDER_HISTORY"
git fetch origin "$PUBLISH_BRANCH_NAME"
git reset --hard "origin/$PUBLISH_BRANCH_NAME"

mkdir -p "$PUBLISH_DIR_LATENCY_HISTORY"

cp ../$SHARED_VOLUME_PATH/$LATENCY_HISTORY_PREFIX*.csv "$PUBLISH_DIR_LATENCY_HISTORY/"
git add "$PUBLISH_DIR_LATENCY_HISTORY"

if git diff-index --quiet HEAD --; then
echo "No changes to commit"
else
git config user.email "github-actions[bot]@users.noreply.github.com"
git config user.name "github-actions[bot]"
git commit -m "Update latency history CSVs"
git push origin "$PUBLISH_BRANCH_NAME"
fi

cd ..
.github/actions/publish_plots/action.yml | 56 (vendored, new file)
@@ -0,0 +1,56 @@
name: Publish Plots
description: "Publish plots in performance_plots branch and add to the workflow summary"

runs:
using: "composite"
steps:
- name: Clone the performance_plots branch
uses: actions/checkout@v4
with:
repository: ${{ github.repository }}
ref: ${{ env.PUBLISH_BRANCH_NAME }}
path: ${{ env.CHECKOUT_SUBFOLDER_SUBPLOTS }}
fetch-depth: 0

- name: Commit & push plots
shell: bash
run: |
cd $CHECKOUT_SUBFOLDER_SUBPLOTS
git fetch origin "$PUBLISH_BRANCH_NAME"
git reset --hard "origin/$PUBLISH_BRANCH_NAME"

# Remove any branch folder older than 7 days
DAYS=7
cutoff=$(( $(date +%s) - DAYS*24*3600 ))
scan_dir="${PUBLISH_DIR_PLOTS%/}"
find "$scan_dir" -mindepth 1 -maxdepth 1 -type d -print0 \
| while IFS= read -r -d $'\0' d; do \
ts=$(git log -1 --format=%ct -- "$d" 2>/dev/null || true); \
if [ -n "$ts" ] && [ "$ts" -le "$cutoff" ]; then \
echo "[cleanup] Deleting: $d"; \
rm -rf -- "$d"; \
fi; \
done

rm -rf $PUBLISH_DIR_PLOTS/$BRANCH_NAME
mkdir -p $PUBLISH_DIR_PLOTS/$BRANCH_NAME

cp ../$SHARED_VOLUME_PATH/*.png $PUBLISH_DIR_PLOTS/$BRANCH_NAME/ 2>/dev/null || true
cp ../$LATENCY_HISTORY_PATH/*.png $PUBLISH_DIR_PLOTS/ 2>/dev/null || true
git add -A "$PUBLISH_DIR_PLOTS/"

git status

if git diff-index --quiet HEAD --; then
echo "No changes to commit"
else
git config user.email "github-actions[bot]@users.noreply.github.com"
git config user.name "github-actions[bot]"
git commit -m "Update performance plots for $BRANCH_NAME"
git push origin $PUBLISH_BRANCH_NAME
fi

- name: Add plots to GitHub Actions summary
shell: bash
run: |
nim c -r -d:release -o:/tmp/add_plots_to_summary ./performance/scripts/add_plots_to_summary.nim
.github/workflows/ci.yml | 6 (vendored)
@@ -25,8 +25,6 @@ jobs:
cpu: i386
- os: linux-gcc-14
cpu: amd64
- os: macos
cpu: amd64
- os: macos-14
cpu: arm64
- os: windows
@@ -45,10 +43,6 @@ jobs:
os: linux-gcc-14
builder: ubuntu-24.04
shell: bash
- platform:
os: macos
builder: macos-13
shell: bash
- platform:
os: macos-14
builder: macos-14
.github/workflows/linters.yml | 2 (vendored)
@@ -19,7 +19,7 @@ jobs:
fetch-depth: 2 # In PR, has extra merge commit: ^1 = PR, ^2 = base

- name: Check `nph` formatting
uses: arnetheduck/nph-action@v1
uses: arnetheduck/nph-action@ef5e9fae6dbaf88ec4308cbede780a0ba45f845d
with:
version: 0.6.1
options: "examples libp2p tests interop tools *.nim*"
.github/workflows/performance.yml | 59 (vendored)
@@ -13,8 +13,8 @@ concurrency:
cancel-in-progress: true

jobs:
examples:
timeout-minutes: 10
performance:
timeout-minutes: 20
strategy:
fail-fast: false

@@ -22,6 +22,25 @@ jobs:
run:
shell: bash

env:
VACP2P: "vacp2p"
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_HEAD_SHA: ${{ github.event.pull_request.head.sha }}
PR_NUMBER: ${{ github.event.number }}
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
MARKER: "<!-- perf-summary-marker -->"
COMMENT_SUMMARY_PATH: "/tmp/perf-summary.md"
SHARED_VOLUME_PATH: "performance/output"
DOCKER_STATS_PREFIX: "docker_stats_"
PUBLISH_BRANCH_NAME: "performance_plots"
CHECKOUT_SUBFOLDER_SUBPLOTS: "subplots"
PUBLISH_DIR_PLOTS: "plots"
CHECKOUT_SUBFOLDER_HISTORY: "history"
PUBLISH_DIR_LATENCY_HISTORY: "latency_history"
LATENCY_HISTORY_PATH: "history/latency_history"
LATENCY_HISTORY_PREFIX: "pr"
LATENCY_HISTORY_PLOT_FILENAME: "latency_history_all_scenarios.png"

name: "Performance"
runs-on: ubuntu-22.04
steps:
@@ -47,23 +66,29 @@ jobs:
run: |
./performance/runner.sh

- name: Set up Nim for aggragate script
uses: jiro4989/setup-nim-action@v2
with:
nim-version: "2.x"
repo-token: ${{ secrets.GITHUB_TOKEN }}
- name: Process latency and docker stats
uses: ./.github/actions/process_stats

- name: Aggregate and display summary
env:
MARKER: "<!-- perf-summary-marker -->"
PR_HEAD_SHA: ${{ github.event.pull_request.head.sha }}
COMMENT_SUMMARY_PATH: "/tmp/perf-summary.md"
run: |
nim c -r -d:release -o:/tmp/aggregate_stats ./performance/aggregate_stats.nim
- name: Publish history
if: github.repository_owner == env.VACP2P
uses: ./.github/actions/publish_history

- name: Post/Update PR Performance Comment
- name: Generate plots
if: github.repository_owner == env.VACP2P
uses: ./.github/actions/generate_plots

- name: Post/Update PR comment
if: github.event_name == 'pull_request'
uses: ./.github/actions/add_comment

- name: Upload performance artifacts
if: success() || failure()
uses: actions/upload-artifact@v4
with:
marker: "<!-- perf-summary-marker -->"
markdown_path: "/tmp/perf-summary.md"
name: performance-artifacts
path: |
performance/output/pr*_latency.csv
performance/output/*.png
history/latency_history/*.png
if-no-files-found: ignore
retention-days: 7
.pinned | 2
@@ -8,7 +8,7 @@ json_serialization;https://github.com/status-im/nim-json-serialization@#2b1c5eb1
metrics;https://github.com/status-im/nim-metrics@#6142e433fc8ea9b73379770a788017ac528d46ff
ngtcp2;https://github.com/status-im/nim-ngtcp2@#9456daa178c655bccd4a3c78ad3b8cce1f0add73
nimcrypto;https://github.com/cheatfate/nimcrypto@#19c41d6be4c00b4a2c8000583bd30cf8ceb5f4b1
quic;https://github.com/status-im/nim-quic.git@#d9a4cbccd509f7a3ee835f75b01dec29d27a0f14
quic;https://github.com/vacp2p/nim-quic@#cae13c2d22ba2730c979486cf89b88927045c3ae
results;https://github.com/arnetheduck/nim-results@#df8113dda4c2d74d460a8fa98252b0b771bf1f27
secp256k1;https://github.com/status-im/nim-secp256k1@#f808ed5e7a7bfc42204ec7830f14b7a42b63c284
serialization;https://github.com/status-im/nim-serialization@#548d0adc9797a10b2db7f788b804330306293088
@@ -1,7 +1,7 @@
mode = ScriptMode.Verbose

packageName = "libp2p"
version = "1.11.0"
version = "1.12.0"
author = "Status Research & Development GmbH"
description = "LibP2P implementation"
license = "MIT"
@@ -10,7 +10,7 @@ skipDirs = @["tests", "examples", "Nim", "tools", "scripts", "docs"]
requires "nim >= 2.0.0",
"nimcrypto >= 0.6.0 & < 0.7.0", "dnsclient >= 0.3.0 & < 0.4.0", "bearssl >= 0.2.5",
"chronicles >= 0.11.0 & < 0.12.0", "chronos >= 4.0.4", "metrics", "secp256k1",
"stew >= 0.4.0", "websock >= 0.2.0", "unittest2", "results", "quic >= 0.2.10",
"stew >= 0.4.0", "websock >= 0.2.0", "unittest2", "results", "quic >= 0.2.15",
"https://github.com/vacp2p/nim-jwt.git#18f8378de52b241f321c1f9ea905456e89b95c6f"

let nimc = getEnv("NIMC", "nim") # Which nim compiler to use
@@ -85,6 +85,7 @@ when defined(libp2p_autotls_support):
../crypto/rsa,
../utils/heartbeat,
../transports/transport,
../utils/ipaddr,
../transports/tcptransport,
../nameresolving/dnsresolver

@@ -150,7 +151,10 @@ when defined(libp2p_autotls_support):
if self.config.ipAddress.isNone():
try:
self.config.ipAddress = Opt.some(getPublicIPAddress())
except AutoTLSError as exc:
except ValueError as exc:
error "Failed to get public IP address", err = exc.msg
return false
except OSError as exc:
error "Failed to get public IP address", err = exc.msg
return false
self.managerFut = self.run(switch)
@@ -22,7 +22,7 @@ const
type AutoTLSError* = object of LPError

when defined(libp2p_autotls_support):
import net, strutils
import strutils
from times import DateTime, toTime, toUnix
import stew/base36
import
@@ -33,36 +33,6 @@ when defined(libp2p_autotls_support):
../nameresolving/nameresolver,
./acme/client

proc checkedGetPrimaryIPAddr*(): IpAddress {.raises: [AutoTLSError].} =
# This is so that we don't need to catch Exceptions directly
# since we support 1.6.16 and getPrimaryIPAddr before nim 2 didn't have explicit .raises. pragmas
try:
return getPrimaryIPAddr()
except Exception as exc:
raise newException(AutoTLSError, "Error while getting primary IP address", exc)

proc isIPv4*(ip: IpAddress): bool =
ip.family == IpAddressFamily.IPv4

proc isPublic*(ip: IpAddress): bool {.raises: [AutoTLSError].} =
let ip = $ip
try:
not (
ip.startsWith("10.") or
(ip.startsWith("172.") and parseInt(ip.split(".")[1]) in 16 .. 31) or
ip.startsWith("192.168.") or ip.startsWith("127.") or ip.startsWith("169.254.")
)
except ValueError as exc:
raise newException(AutoTLSError, "Failed to parse IP address", exc)

proc getPublicIPAddress*(): IpAddress {.raises: [AutoTLSError].} =
let ip = checkedGetPrimaryIPAddr()
if not ip.isIPv4():
raise newException(AutoTLSError, "Host does not have an IPv4 address")
if not ip.isPublic():
raise newException(AutoTLSError, "Host does not have a public IPv4 address")
return ip

proc asMoment*(dt: DateTime): Moment =
let unixTime: int64 = dt.toTime.toUnix
return Moment.init(unixTime, Second)
@@ -26,7 +26,8 @@ import
transports/[transport, tcptransport, wstransport, memorytransport],
muxers/[muxer, mplex/mplex, yamux/yamux],
protocols/[identify, secure/secure, secure/noise, rendezvous],
protocols/connectivity/[autonat/server, relay/relay, relay/client, relay/rtransport],
protocols/connectivity/
[autonat/server, autonatv2/server, relay/relay, relay/client, relay/rtransport],
connmanager,
upgrademngrs/muxedupgrade,
observedaddrmanager,
@@ -74,6 +75,7 @@ type
nameResolver: NameResolver
peerStoreCapacity: Opt[int]
autonat: bool
autonatV2: bool
autotls: AutotlsService
circuitRelay: Relay
rdv: RendezVous
@@ -280,6 +282,10 @@ proc withAutonat*(b: SwitchBuilder): SwitchBuilder =
b.autonat = true
b

proc withAutonatV2*(b: SwitchBuilder): SwitchBuilder =
b.autonatV2 = true
b

when defined(libp2p_autotls_support):
proc withAutotls*(
b: SwitchBuilder, config: AutotlsConfig = AutotlsConfig.new()
@@ -379,7 +385,10 @@ proc build*(b: SwitchBuilder): Switch {.raises: [LPError], public.} =

switch.mount(identify)

if b.autonat:
if b.autonatV2:
let autonatV2 = AutonatV2.new(switch)
switch.mount(autonatV2)
elif b.autonat:
let autonat = Autonat.new(switch)
switch.mount(autonat)

@@ -395,7 +404,7 @@ proc build*(b: SwitchBuilder): Switch {.raises: [LPError], public.} =

return switch

proc newStandardSwitch*(
proc newStandardSwitchBuilder*(
privKey = none(PrivateKey),
addrs: MultiAddress | seq[MultiAddress] =
MultiAddress.init("/ip4/127.0.0.1/tcp/0").expect("valid address"),
@@ -411,7 +420,7 @@ proc newStandardSwitch*(
nameResolver: NameResolver = nil,
sendSignedPeerRecord = false,
peerStoreCapacity = 1000,
): Switch {.raises: [LPError], public.} =
): SwitchBuilder {.raises: [LPError], public.} =
## Helper for common switch configurations.
let addrs =
when addrs is MultiAddress:
@@ -436,4 +445,39 @@ proc newStandardSwitch*(
privKey.withValue(pkey):
b = b.withPrivateKey(pkey)

b.build()
b

proc newStandardSwitch*(
privKey = none(PrivateKey),
addrs: MultiAddress | seq[MultiAddress] =
MultiAddress.init("/ip4/127.0.0.1/tcp/0").expect("valid address"),
secureManagers: openArray[SecureProtocol] = [SecureProtocol.Noise],
transportFlags: set[ServerFlags] = {},
rng = newRng(),
inTimeout: Duration = 5.minutes,
outTimeout: Duration = 5.minutes,
maxConnections = MaxConnections,
maxIn = -1,
maxOut = -1,
maxConnsPerPeer = MaxConnectionsPerPeer,
nameResolver: NameResolver = nil,
sendSignedPeerRecord = false,
peerStoreCapacity = 1000,
): Switch {.raises: [LPError], public.} =
newStandardSwitchBuilder(
privKey = privKey,
addrs = addrs,
secureManagers = secureManagers,
transportFlags = transportFlags,
rng = rng,
inTimeout = inTimeout,
outTimeout = outTimeout,
maxConnections = maxConnections,
maxIn = maxIn,
maxOut = maxOut,
maxConnsPerPeer = maxConnsPerPeer,
nameResolver = nameResolver,
sendSignedPeerRecord = sendSignedPeerRecord,
peerStoreCapacity = peerStoreCapacity,
)
.build()
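The builder change above splits switch construction into two steps: newStandardSwitchBuilder now returns the configured SwitchBuilder, and newStandardSwitch is a thin wrapper that calls .build() on it. A minimal sketch of enabling the new AutonatV2 option on top of that (withAutonatV2 and newStandardSwitchBuilder come from this diff; the listen address is an arbitrary example):

```nim
import chronos
import libp2p

proc startNode() {.async.} =
  # Configure a standard switch, opt into the AutonatV2 server, then build.
  let switch = newStandardSwitchBuilder(
      addrs = MultiAddress.init("/ip4/0.0.0.0/tcp/0").expect("valid address")
    )
    .withAutonatV2() # mounts AutonatV2 instead of the v1 Autonat server
    .build()
  await switch.start()
```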
@@ -11,7 +11,7 @@
import chronos
import results
import peerid, stream/connection, transports/transport
import peerid, stream/connection, transports/transport, muxers/muxer

export results

@@ -65,6 +65,23 @@ method dial*(
method addTransport*(self: Dial, transport: Transport) {.base.} =
doAssert(false, "[Dial.addTransport] abstract method not implemented!")

method dialAndUpgrade*(
self: Dial, peerId: Opt[PeerId], address: MultiAddress, dir = Direction.Out
): Future[Muxer] {.base, async: (raises: [CancelledError]).} =
doAssert(false, "[Dial.dialAndUpgrade] abstract method not implemented!")

method dialAndUpgrade*(
self: Dial, peerId: Opt[PeerId], addrs: seq[MultiAddress], dir = Direction.Out
): Future[Muxer] {.
base, async: (raises: [CancelledError, MaError, TransportAddressError, LPError])
.} =
doAssert(false, "[Dial.dialAndUpgrade] abstract method not implemented!")

method negotiateStream*(
self: Dial, conn: Connection, protos: seq[string]
): Future[Connection] {.base, async: (raises: [CatchableError]).} =
doAssert(false, "[Dial.negotiateStream] abstract method not implemented!")

method tryDial*(
self: Dial, peerId: PeerId, addrs: seq[MultiAddress]
): Future[Opt[MultiAddress]] {.
@@ -43,7 +43,7 @@ type Dialer* = ref object of Dial
peerStore: PeerStore
nameResolver: NameResolver

proc dialAndUpgrade(
proc dialAndUpgrade*(
self: Dialer,
peerId: Opt[PeerId],
hostname: string,
@@ -139,7 +139,7 @@ proc expandDnsAddr(
else:
result.add((resolvedAddress, peerId))

proc dialAndUpgrade(
proc dialAndUpgrade*(
self: Dialer, peerId: Opt[PeerId], addrs: seq[MultiAddress], dir = Direction.Out
): Future[Muxer] {.
async: (raises: [CancelledError, MaError, TransportAddressError, LPError])
@@ -284,7 +284,7 @@ method connect*(
return
(await self.internalConnect(Opt.none(PeerId), @[address], false)).connection.peerId

proc negotiateStream(
proc negotiateStream*(
self: Dialer, conn: Connection, protos: seq[string]
): Future[Connection] {.async: (raises: [CatchableError]).} =
trace "Negotiating stream", conn, protos
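dialAndUpgrade and negotiateStream become exported on both Dial and Dialer, which is what lets the AutonatV2 server further down bypass the connection manager and force a fresh, muxed connection for the dial-back. A hedged sketch of that calling pattern (mirroring forceNewConnection in the new server.nim; the protocol string is a placeholder):

```nim
import chronos, results
import libp2p

proc freshStream(switch: Switch, peerId: PeerId,
                 addrs: seq[MultiAddress]): Future[Connection] {.async.} =
  # Open a brand-new muxed connection to the peer, without reusing anything
  # tracked by the ConnManager, then negotiate a protocol on a new stream.
  let mux = await switch.dialer.dialAndUpgrade(Opt.some(peerId), addrs)
  if mux.isNil():
    raise newException(DialFailedError, "could not dial peer")
  return await switch.negotiateStream(await mux.newStream(), @["/my/proto/1.0.0"])
```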
@@ -27,7 +27,7 @@ macro checkFutures*[F](futs: seq[F], exclude: untyped = []): untyped =
quote:
for res in `futs`:
if res.failed:
let exc = res.readError()
let exc = res.error
# We still don't abort but warn
debug "A future has failed, enable trace logging for details",
error = exc.name
@@ -37,7 +37,7 @@ macro checkFutures*[F](futs: seq[F], exclude: untyped = []): untyped =
for res in `futs`:
block check:
if res.failed:
let exc = res.readError()
let exc = res.error
for i in 0 ..< `nexclude`:
if exc of `exclude`[i]:
trace "A future has failed", error = exc.name, description = exc.msg
@@ -843,6 +843,14 @@ proc init*(
res.data.finish()
ok(res)

proc getPart*(ma: MultiAddress, codec: MultiCodec): MaResult[MultiAddress] =
## Returns the first multiaddress in ``value`` with codec ``codec``
for part in ma:
let part = ?part
if codec == ?part.protoCode:
return ok(part)
err("no such codec in multiaddress")

proc getProtocol(name: string): MAProtocol {.inline.} =
let mc = MultiCodec.codec(name)
if mc != InvalidMultiCodec:
@@ -1119,3 +1127,32 @@ proc getRepeatedField*(
err(ProtoError.IncorrectBlob)
else:
ok(true)

proc areAddrsConsistent*(a, b: MultiAddress): bool =
## Checks if two multiaddresses have the same protocol stack.
let protosA = a.protocols().get()
let protosB = b.protocols().get()
if protosA.len != protosB.len:
return false

for idx in 0 ..< protosA.len:
let protoA = protosA[idx]
let protoB = protosB[idx]

if protoA != protoB:
if idx == 0:
# allow DNS ↔ IP at the first component
if protoB == multiCodec("dns") or protoB == multiCodec("dnsaddr"):
if not (protoA == multiCodec("ip4") or protoA == multiCodec("ip6")):
return false
elif protoB == multiCodec("dns4"):
if protoA != multiCodec("ip4"):
return false
elif protoB == multiCodec("dns6"):
if protoA != multiCodec("ip6"):
return false
else:
return false
else:
return false
true
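A small sketch of how the two new multiaddress helpers above could be used; the addresses are arbitrary examples:

```nim
import libp2p/[multiaddress, multicodec]

let a = MultiAddress.init("/ip4/192.0.2.1/tcp/4001").expect("valid address")
let b = MultiAddress.init("/dns4/example.com/tcp/4001").expect("valid address")

# getPart: pull the first component with a given codec out of an address.
echo a.getPart(multiCodec("tcp"))   # ok(...): the /tcp/4001 component

# areAddrsConsistent: same protocol stack, with DNS <-> IP allowed in the
# first component, so an ip4 and a dns4 address over tcp are consistent.
assert areAddrsConsistent(a, b)
```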
@@ -150,6 +150,10 @@ method close*(s: LPChannel) {.async: (raises: []).} =
trace "Closed channel", s, len = s.len

method closeWrite*(s: LPChannel) {.async: (raises: []).} =
## For mplex, closeWrite is the same as close - it implements half-close
await s.close()

method initStream*(s: LPChannel) =
if s.objName.len == 0:
s.objName = LPChannelTrackerName
@@ -95,6 +95,7 @@ proc newStreamInternal*(
result.peerId = m.connection.peerId
result.observedAddr = m.connection.observedAddr
result.localAddr = m.connection.localAddr
result.transportDir = m.connection.transportDir
when defined(libp2p_agents_metrics):
result.shortAgent = m.connection.shortAgent
@@ -54,6 +54,10 @@ method newStream*(
.} =
raiseAssert("[Muxer.newStream] abstract method not implemented!")

when defined(libp2p_agents_metrics):
method setShortAgent*(m: Muxer, shortAgent: string) {.base, gcsafe.} =
m.connection.shortAgent = shortAgent

method close*(m: Muxer) {.base, async: (raises: []).} =
if m.connection != nil:
await m.connection.close()
@@ -217,7 +217,11 @@ method closeImpl*(channel: YamuxChannel) {.async: (raises: []).} =
discard
await channel.actuallyClose()

proc clearQueues(channel: YamuxChannel, error: ref CatchableError = nil) =
method closeWrite*(channel: YamuxChannel) {.async: (raises: []).} =
## For yamux, closeWrite is the same as close - it implements half-close
await channel.close()

proc clearQueues(channel: YamuxChannel, error: ref LPStreamEOFError = nil) =
for toSend in channel.sendQueue:
if error.isNil():
toSend.fut.complete()
@@ -511,6 +515,7 @@ proc createStream(
stream.initStream()
stream.peerId = m.connection.peerId
stream.observedAddr = m.connection.observedAddr
stream.localAddr = m.connection.localAddr
stream.transportDir = m.connection.transportDir
when defined(libp2p_agents_metrics):
stream.shortAgent = m.connection.shortAgent
@@ -529,13 +534,13 @@ method close*(m: Yamux) {.async: (raises: []).} =
trace "Closing yamux"
let channels = toSeq(m.channels.values())
for channel in channels:
for toSend in channel.sendQueue:
toSend.fut.fail(newLPStreamEOFError())
channel.sendQueue = @[]
channel.clearQueues(newLPStreamEOFError())
channel.recvWindow = 0
channel.sendWindow = 0
channel.closedLocally = true
channel.isReset = true
channel.opened = false
channel.isClosed = true
await channel.remoteClosed()
channel.receivedData.fire()
try:
@@ -607,8 +612,10 @@ method handle*(m: Yamux) {.async: (raises: []).} =
if header.length > 0:
var buffer = newSeqUninit[byte](header.length)
await m.connection.readExactly(addr buffer[0], int(header.length))
do:
raise newException(YamuxError, "Unknown stream ID: " & $header.streamId)

# If we do not have a stream, likely we sent a RST and/or closed the stream
trace "unknown stream id", id = header.streamId

continue

let channel =
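Both mplex (LPChannel) and yamux channels now expose closeWrite, and in both muxers it forwards to close, which for these muxers is a half-close of the write side. A hedged usage sketch, assuming a closeWrite base method exists on the stream type that these channel overrides implement, and with the request bytes and size cap as placeholders:

```nim
import chronos
import libp2p/stream/connection

proc requestReply(conn: Connection, request: seq[byte]): Future[seq[byte]] {.async.} =
  # Write the request, then half-close our write side; the remote sees EOF
  # on its read end while the reverse direction stays usable for the reply.
  await conn.writeLp(request)
  await conn.closeWrite()
  return await conn.readLp(64 * 1024) # arbitrary reply size cap for the sketch
```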
@@ -214,7 +214,7 @@ proc identify*(
info.agentVersion.get("").split("/")[0].safeToLowerAscii().get("")
if KnownLibP2PAgentsSeq.contains(shortAgent):
knownAgent = shortAgent
muxer.connection.setShortAgent(knownAgent)
muxer.setShortAgent(knownAgent)

peerStore.updatePeerInfo(info, stream.observedAddr)
finally:
@@ -12,7 +12,7 @@
import results
import chronos, chronicles
import ../../../switch, ../../../multiaddress, ../../../peerid
import core
import types

logScope:
topics = "libp2p autonat"
@@ -20,9 +20,9 @@ import
../../../peerid,
../../../utils/[semaphore, future],
../../../errors
import core
import types

export core
export types

logScope:
topics = "libp2p autonat"
@@ -14,11 +14,11 @@ import chronos, metrics
import ../../../switch
import ../../../wire
import client
from core import NetworkReachability, AutonatUnreachableError
from types import NetworkReachability, AutonatUnreachableError
import ../../../utils/heartbeat
import ../../../crypto/crypto

export core.NetworkReachability
export NetworkReachability

logScope:
topics = "libp2p autonatservice"
libp2p/protocols/connectivity/autonatv2/server.nim | 276 (new file)
@@ -0,0 +1,276 @@
# Nim-LibP2P
# Copyright (c) 2025 Status Research & Development GmbH
# Licensed under either of
#  * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
#  * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.

{.push raises: [].}

import results
import chronos, chronicles
import
../../../../libp2p/[
switch,
muxers/muxer,
dialer,
multiaddress,
transports/transport,
multicodec,
peerid,
protobuf/minprotobuf,
utils/ipaddr,
],
../../protocol,
./types

logScope:
topics = "libp2p autonat v2 server"

const
DefaultDialTimeout: Duration = 15.seconds
DefaultAmplificationAttackDialTimeout: Duration = 3.seconds
DefaultDialDataSize: uint64 = 50 * 1024 # 50 KiB > 50 KB
AutonatV2MsgLpSize: int = 1024
# readLp needs to receive more than 4096 bytes (since it's a DialDataResponse) + overhead
AutonatV2DialDataResponseLpSize: int = 5000

type AutonatV2* = ref object of LPProtocol
switch*: Switch
dialTimeout: Duration
dialDataSize: uint64
amplificationAttackTimeout: Duration

proc sendDialResponse(
conn: Connection,
status: ResponseStatus,
addrIdx: Opt[AddrIdx] = Opt.none(AddrIdx),
dialStatus: Opt[DialStatus] = Opt.none(DialStatus),
) {.async: (raises: [CancelledError, LPStreamError]).} =
await conn.writeLp(
AutonatV2Msg(
msgType: MsgType.DialResponse,
dialResp: DialResponse(status: status, addrIdx: addrIdx, dialStatus: dialStatus),
).encode().buffer
)

proc findObservedIPAddr*(
conn: Connection, req: DialRequest
): Future[Opt[MultiAddress]] {.async: (raises: [CancelledError, LPStreamError]).} =
let observedAddr = conn.observedAddr.valueOr:
await conn.sendDialResponse(ResponseStatus.EInternalError)
return Opt.none(MultiAddress)

let isRelayed = observedAddr.contains(multiCodec("p2p-circuit")).valueOr:
error "Invalid observed address"
await conn.sendDialResponse(ResponseStatus.EDialRefused)
return Opt.none(MultiAddress)

if isRelayed:
error "Invalid observed address: relayed address"
await conn.sendDialResponse(ResponseStatus.EDialRefused)
return Opt.none(MultiAddress)

let hostIp = observedAddr[0].valueOr:
error "Invalid observed address"
await conn.sendDialResponse(ResponseStatus.EInternalError)
return Opt.none(MultiAddress)

return Opt.some(hostIp)

proc dialBack(
conn: Connection, nonce: Nonce
): Future[DialStatus] {.
async: (raises: [CancelledError, DialFailedError, LPStreamError])
.} =
try:
# send dial back
await conn.writeLp(DialBack(nonce: nonce).encode().buffer)

# receive DialBackResponse
let dialBackResp = DialBackResponse.decode(
initProtoBuffer(await conn.readLp(AutonatV2MsgLpSize))
).valueOr:
error "DialBack failed, could not decode DialBackResponse"
return DialStatus.EDialBackError
except LPStreamRemoteClosedError as exc:
# failed because of nonce error (remote reset the stream): EDialBackError
error "DialBack failed, remote closed the connection", description = exc.msg
return DialStatus.EDialBackError

# TODO: failed because of client or server resources: EDialError

trace "DialBack successful"
return DialStatus.Ok

proc handleDialDataResponses(
self: AutonatV2, conn: Connection
) {.async: (raises: [CancelledError, AutonatV2Error, LPStreamError]).} =
var dataReceived: uint64 = 0

while dataReceived < self.dialDataSize:
let msg = AutonatV2Msg.decode(
initProtoBuffer(await conn.readLp(AutonatV2DialDataResponseLpSize))
).valueOr:
raise newException(AutonatV2Error, "Received malformed message")
debug "Received message", msgType = $msg.msgType
if msg.msgType != MsgType.DialDataResponse:
raise
newException(AutonatV2Error, "Expecting DialDataResponse, got " & $msg.msgType)
let resp = msg.dialDataResp
dataReceived += resp.data.len.uint64
debug "received data",
dataReceived = resp.data.len.uint64, totalDataReceived = dataReceived

proc amplificationAttackPrevention(
self: AutonatV2, conn: Connection, addrIdx: AddrIdx
): Future[bool] {.async: (raises: [CancelledError, LPStreamError]).} =
# send DialDataRequest
await conn.writeLp(
AutonatV2Msg(
msgType: MsgType.DialDataRequest,
dialDataReq: DialDataRequest(addrIdx: addrIdx, numBytes: self.dialDataSize),
).encode().buffer
)

# recieve DialDataResponses until we're satisfied
try:
if not await self.handleDialDataResponses(conn).withTimeout(self.dialTimeout):
error "Amplification attack prevention timeout",
timeout = self.amplificationAttackTimeout, peer = conn.peerId
return false
except AutonatV2Error as exc:
error "Amplification attack prevention failed", description = exc.msg
return false

return true

proc canDial(self: AutonatV2, addrs: MultiAddress): bool =
let (ipv4Support, ipv6Support) = self.switch.peerInfo.listenAddrs.ipSupport()
addrs[0].withValue(addrIp):
if IP4.match(addrIp) and not ipv4Support:
return false
if IP6.match(addrIp) and not ipv6Support:
return false
try:
if isPrivate($addrIp):
return false
except ValueError:
warn "Unable to parse IP address, skipping", addrs = $addrIp
return false
for t in self.switch.transports:
if t.handles(addrs):
return true
return false

proc forceNewConnection(
self: AutonatV2, pid: PeerId, addrs: seq[MultiAddress]
): Future[Opt[Connection]] {.async: (raises: [CancelledError]).} =
## Bypasses connManager to force a new connection to ``pid``
## instead of reusing a preexistent one
try:
let mux = await self.switch.dialer.dialAndUpgrade(Opt.some(pid), addrs)
if mux.isNil():
return Opt.none(Connection)
return Opt.some(
await self.switch.negotiateStream(
await mux.newStream(), @[$AutonatV2Codec.DialBack]
)
)
except CancelledError as exc:
raise exc
except CatchableError:
return Opt.none(Connection)

proc chooseDialAddr(
self: AutonatV2, pid: PeerId, addrs: seq[MultiAddress]
): Future[Opt[(Connection, AddrIdx)]] {.async: (raises: [CancelledError]).} =
for i, ma in addrs:
if self.canDial(ma):
debug "Trying to dial", chosenAddrs = ma, addrIdx = i
let conn = (await self.forceNewConnection(pid, @[ma])).valueOr:
return Opt.none((Connection, AddrIdx))
return Opt.some((conn, i.AddrIdx))
return Opt.none((Connection, AddrIdx))

proc handleDialRequest(
self: AutonatV2, conn: Connection, req: DialRequest
) {.async: (raises: [CancelledError, DialFailedError, LPStreamError]).} =
let observedIPAddr = (await conn.findObservedIPAddr(req)).valueOr:
error "Could not find observed IP address"
return

let (dialBackConn, addrIdx) = (await self.chooseDialAddr(conn.peerId, req.addrs)).valueOr:
error "No dialable addresses found"
await conn.sendDialResponse(ResponseStatus.EDialRefused)
return
defer:
await dialBackConn.close()

# if observed address for peer is not in address list to try
# then we perform Amplification Attack Prevention
if not ipAddrMatches(observedIPAddr, req.addrs):
debug "Starting amplification attack prevention",
observedIPAddr = observedIPAddr, testAddr = req.addrs[addrIdx]
# send DialDataRequest and wait until dataReceived is enough
if not await self.amplificationAttackPrevention(conn, addrIdx):
return

debug "Sending DialBack",
nonce = req.nonce, addrIdx = addrIdx, addr = req.addrs[addrIdx]

let dialStatus = await dialBackConn.dialBack(req.nonce)

await conn.sendDialResponse(
ResponseStatus.Ok, addrIdx = Opt.some(addrIdx), dialStatus = Opt.some(dialStatus)
)

proc new*(
T: typedesc[AutonatV2],
switch: Switch,
dialTimeout: Duration = DefaultDialTimeout,
dialDataSize: uint64 = DefaultDialDataSize,
amplificationAttackTimeout: Duration = DefaultAmplificationAttackDialTimeout,
): T =
let autonatV2 = T(
switch: switch,
dialTimeout: dialTimeout,
dialDataSize: dialDataSize,
amplificationAttackTimeout: amplificationAttackTimeout,
)
proc handleStream(
conn: Connection, proto: string
) {.async: (raises: [CancelledError]).} =
defer:
await conn.close()

let msg =
try:
AutonatV2Msg.decode(initProtoBuffer(await conn.readLp(AutonatV2MsgLpSize))).valueOr:
trace "Unable to decode AutonatV2Msg"
return
except LPStreamError as exc:
debug "Could not receive AutonatV2Msg", description = exc.msg
return

debug "Received message", msgType = $msg.msgType
if msg.msgType != MsgType.DialRequest:
debug "Expecting DialRequest", receivedMsgType = msg.msgType
return

try:
await autonatV2.handleDialRequest(conn, msg.dialReq)
except CancelledError as exc:
raise exc
except LPStreamRemoteClosedError as exc:
debug "Connection closed by peer", description = exc.msg, peer = conn.peerId
except LPStreamError as exc:
debug "Stream Error", description = exc.msg
except DialFailedError as exc:
debug "Could not dial peer", description = exc.msg, peer = conn.peerId

autonatV2.handler = handleStream
autonatV2.codec = $AutonatV2Codec.DialRequest
autonatV2
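The server above expects a length-prefixed AutonatV2Msg carrying a DialRequest on the dial-request codec and answers with a DialResponse, possibly after a DialDataRequest round for amplification-attack prevention. A rough client-side sketch of the happy path, built only from the types in this diff; the real client, dial-back listener, and nonce verification are not part of this file and are omitted here:

```nim
import chronos, results
import libp2p
import libp2p/protobuf/minprotobuf
import libp2p/protocols/connectivity/autonatv2/types

proc checkReachability(conn: Connection, myAddrs: seq[MultiAddress]) {.async.} =
  # `conn` is assumed to be a stream already negotiated on the
  # /libp2p/autonat/2/dial-request codec.
  let req = AutonatV2Msg(
    msgType: MsgType.DialRequest,
    dialReq: DialRequest(addrs: myAddrs, nonce: Nonce(12345)),
  )
  await conn.writeLp(req.encode().buffer)

  # The server dials us back on /libp2p/autonat/2/dial-back with the nonce,
  # then reports the outcome in a DialResponse.
  let reply = AutonatV2Msg.decode(initProtoBuffer(await conn.readLp(1024))).valueOr:
    raise newException(AutonatV2Error, "malformed reply")
  if reply.msgType == MsgType.DialResponse:
    echo "dial status: ", reply.dialResp.dialStatus
```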
libp2p/protocols/connectivity/autonatv2/types.nim | 254 (new file)
@@ -0,0 +1,254 @@
# Nim-LibP2P
# Copyright (c) 2025 Status Research & Development GmbH
# Licensed under either of
#  * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
#  * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.

{.push raises: [].}

import results, chronos, chronicles
import
../../../multiaddress, ../../../peerid, ../../../protobuf/minprotobuf, ../../../switch
from ../autonat/types import NetworkReachability

type
AutonatV2Codec* {.pure.} = enum
DialRequest = "/libp2p/autonat/2/dial-request"
DialBack = "/libp2p/autonat/2/dial-back"

AutonatV2Response* = object
reachability*: NetworkReachability
dialResp*: DialResponse
addrs*: Opt[MultiAddress]

AutonatV2Error* = object of LPError

Nonce* = uint64

AddrIdx* = uint32

NumBytes* = uint64

MsgType* {.pure.} = enum
# DialBack and DialBackResponse are not defined as AutonatV2Msg as per the spec
# likely because they are expected in response to some other message
DialRequest
DialResponse
DialDataRequest
DialDataResponse

ResponseStatus* {.pure.} = enum
EInternalError = 0
ERequestRejected = 100
EDialRefused = 101
Ok = 200

DialBackStatus* {.pure.} = enum
Ok = 0

DialStatus* {.pure.} = enum
Unused = 0
EDialError = 100
EDialBackError = 101
Ok = 200

DialRequest* = object
addrs*: seq[MultiAddress]
nonce*: Nonce

DialResponse* = object
status*: ResponseStatus
addrIdx*: Opt[AddrIdx]
dialStatus*: Opt[DialStatus]

DialBack* = object
nonce*: Nonce

DialBackResponse* = object
status*: DialBackStatus

DialDataRequest* = object
addrIdx*: AddrIdx
numBytes*: NumBytes

DialDataResponse* = object
data*: seq[byte]

AutonatV2Msg* = object
case msgType*: MsgType
of MsgType.DialRequest:
dialReq*: DialRequest
of MsgType.DialResponse:
dialResp*: DialResponse
of MsgType.DialDataRequest:
dialDataReq*: DialDataRequest
of MsgType.DialDataResponse:
dialDataResp*: DialDataResponse

# DialRequest
proc encode*(dialReq: DialRequest): ProtoBuffer =
var encoded = initProtoBuffer()
for ma in dialReq.addrs:
encoded.write(1, ma.data.buffer)
encoded.write(2, dialReq.nonce)
encoded.finish()
encoded

proc decode*(T: typedesc[DialRequest], pb: ProtoBuffer): Opt[T] =
var
addrs: seq[MultiAddress]
nonce: Nonce
if not ?pb.getRepeatedField(1, addrs).toOpt():
return Opt.none(T)
if not ?pb.getField(2, nonce).toOpt():
return Opt.none(T)
Opt.some(T(addrs: addrs, nonce: nonce))

# DialResponse
proc encode*(dialResp: DialResponse): ProtoBuffer =
var encoded = initProtoBuffer()
encoded.write(1, dialResp.status.uint)
# minprotobuf casts uses float64 for fixed64 fields
dialResp.addrIdx.withValue(addrIdx):
encoded.write(2, addrIdx)
dialResp.dialStatus.withValue(dialStatus):
encoded.write(3, dialStatus.uint)
encoded.finish()
encoded

proc decode*(T: typedesc[DialResponse], pb: ProtoBuffer): Opt[T] =
var
status: uint
addrIdx: AddrIdx
dialStatus: uint

if not ?pb.getField(1, status).toOpt():
return Opt.none(T)

var optAddrIdx = Opt.none(AddrIdx)
if ?pb.getField(2, addrIdx).toOpt():
optAddrIdx = Opt.some(addrIdx)

var optDialStatus = Opt.none(DialStatus)
if ?pb.getField(3, dialStatus).toOpt():
optDialStatus = Opt.some(cast[DialStatus](dialStatus))

Opt.some(
T(
status: cast[ResponseStatus](status),
addrIdx: optAddrIdx,
dialStatus: optDialStatus,
)
)

# DialBack
proc encode*(dialBack: DialBack): ProtoBuffer =
var encoded = initProtoBuffer()
encoded.write(1, dialBack.nonce)
encoded.finish()
encoded

proc decode*(T: typedesc[DialBack], pb: ProtoBuffer): Opt[T] =
var nonce: Nonce
if not ?pb.getField(1, nonce).toOpt():
return Opt.none(T)
Opt.some(T(nonce: nonce))

# DialBackResponse
proc encode*(dialBackResp: DialBackResponse): ProtoBuffer =
var encoded = initProtoBuffer()
encoded.write(1, dialBackResp.status.uint)
encoded.finish()
encoded

proc decode*(T: typedesc[DialBackResponse], pb: ProtoBuffer): Opt[T] =
var status: uint
if not ?pb.getField(1, status).toOpt():
return Opt.none(T)
Opt.some(T(status: cast[DialBackStatus](status)))

# DialDataRequest
proc encode*(dialDataReq: DialDataRequest): ProtoBuffer =
var encoded = initProtoBuffer()
encoded.write(1, dialDataReq.addrIdx)
encoded.write(2, dialDataReq.numBytes)
encoded.finish()
encoded

proc decode*(T: typedesc[DialDataRequest], pb: ProtoBuffer): Opt[T] =
var
addrIdx: AddrIdx
numBytes: NumBytes
if not ?pb.getField(1, addrIdx).toOpt():
return Opt.none(T)
if not ?pb.getField(2, numBytes).toOpt():
return Opt.none(T)
Opt.some(T(addrIdx: addrIdx, numBytes: numBytes))

# DialDataResponse
proc encode*(dialDataResp: DialDataResponse): ProtoBuffer =
var encoded = initProtoBuffer()
encoded.write(1, dialDataResp.data)
encoded.finish()
encoded

proc decode*(T: typedesc[DialDataResponse], pb: ProtoBuffer): Opt[T] =
var data: seq[byte]
if not ?pb.getField(1, data).toOpt():
return Opt.none(T)
Opt.some(T(data: data))

proc protoField(msgType: MsgType): int =
case msgType
of MsgType.DialRequest: 1.int
of MsgType.DialResponse: 2.int
of MsgType.DialDataRequest: 3.int
of MsgType.DialDataResponse: 4.int

# AutonatV2Msg
proc encode*(msg: AutonatV2Msg): ProtoBuffer =
var encoded = initProtoBuffer()
case msg.msgType
of MsgType.DialRequest:
encoded.write(MsgType.DialRequest.protoField, msg.dialReq.encode())
of MsgType.DialResponse:
encoded.write(MsgType.DialResponse.protoField, msg.dialResp.encode())
of MsgType.DialDataRequest:
encoded.write(MsgType.DialDataRequest.protoField, msg.dialDataReq.encode())
of MsgType.DialDataResponse:
encoded.write(MsgType.DialDataResponse.protoField, msg.dialDataResp.encode())
encoded.finish()
encoded

proc decode*(T: typedesc[AutonatV2Msg], pb: ProtoBuffer): Opt[T] =
var
msgTypeOrd: uint32
msg: ProtoBuffer

if ?pb.getField(MsgType.DialRequest.protoField, msg).toOpt():
let dialReq = DialRequest.decode(msg).valueOr:
return Opt.none(AutonatV2Msg)
Opt.some(AutonatV2Msg(msgType: MsgType.DialRequest, dialReq: dialReq))
elif ?pb.getField(MsgType.DialResponse.protoField, msg).toOpt():
let dialResp = DialResponse.decode(msg).valueOr:
return Opt.none(AutonatV2Msg)
Opt.some(AutonatV2Msg(msgType: MsgType.DialResponse, dialResp: dialResp))
elif ?pb.getField(MsgType.DialDataRequest.protoField, msg).toOpt():
let dialDataReq = DialDataRequest.decode(msg).valueOr:
return Opt.none(AutonatV2Msg)
Opt.some(AutonatV2Msg(msgType: MsgType.DialDataRequest, dialDataReq: dialDataReq))
elif ?pb.getField(MsgType.DialDataResponse.protoField, msg).toOpt():
let dialDataResp = DialDataResponse.decode(msg).valueOr:
return Opt.none(AutonatV2Msg)
Opt.some(
AutonatV2Msg(msgType: MsgType.DialDataResponse, dialDataResp: dialDataResp)
)
else:
Opt.none(AutonatV2Msg)

# Custom `==` is needed to compare since AutonatV2Msg is a case object
proc `==`*(a, b: AutonatV2Msg): bool =
a.msgType == b.msgType and a.encode() == b.encode()
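Each message gets a protobuf encode/decode pair, and AutonatV2Msg wraps them in a oneof-style case object keyed by protoField. A small round-trip sketch built only from the procs above:

```nim
import results
import libp2p/protobuf/minprotobuf
import libp2p/protocols/connectivity/autonatv2/types

let msg = AutonatV2Msg(
  msgType: MsgType.DialDataRequest,
  dialDataReq: DialDataRequest(addrIdx: 0, numBytes: 50 * 1024),
)
let wire = msg.encode().buffer # seq[byte], ready for writeLp on a stream
let decoded = AutonatV2Msg.decode(initProtoBuffer(wire)).valueOr:
  raise newException(AutonatV2Error, "decode failed")
assert decoded == msg # uses the custom `==` defined above
```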
libp2p/protocols/connectivity/autonatv2/utils.nim | 47 (new file)
@@ -0,0 +1,47 @@
{.push raises: [].}

import results
import chronos
import
../../protocol,
../../../switch,
../../../multiaddress,
../../../multicodec,
../../../peerid,
../../../protobuf/minprotobuf,
../autonat/service,
./types

proc asNetworkReachability*(self: DialResponse): NetworkReachability =
if self.status == EInternalError:
return Unknown
if self.status == ERequestRejected:
return Unknown
if self.status == EDialRefused:
return Unknown

# if got here it means a dial was attempted
let dialStatus = self.dialStatus.valueOr:
return Unknown
if dialStatus == Unused:
return Unknown
if dialStatus == EDialError:
return NotReachable
if dialStatus == EDialBackError:
return NotReachable
return Reachable

proc asAutonatV2Response*(
self: DialResponse, testAddrs: seq[MultiAddress]
): AutonatV2Response =
let addrIdx = self.addrIdx.valueOr:
return AutonatV2Response(
reachability: self.asNetworkReachability(),
dialResp: self,
addrs: Opt.none(MultiAddress),
)
AutonatV2Response(
reachability: self.asNetworkReachability(),
dialResp: self,
addrs: Opt.some(testAddrs[addrIdx]),
)
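asNetworkReachability collapses a DialResponse into the existing NetworkReachability enum: any non-Ok response status maps to Unknown, a failed dial or dial-back maps to NotReachable, and only DialStatus.Ok yields Reachable. Two illustrative cases, a sketch assuming the modules above:

```nim
import results
import libp2p/protocols/connectivity/autonatv2/[types, utils]

let okResp = DialResponse(
  status: ResponseStatus.Ok,
  addrIdx: Opt.some(AddrIdx(0)),
  dialStatus: Opt.some(DialStatus.Ok),
)
echo okResp.asNetworkReachability()      # Reachable

let refused = DialResponse(status: ResponseStatus.EDialRefused)
echo refused.asNetworkReachability()     # Unknown
```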
@@ -1,6 +1,7 @@
|
||||
import chronos
|
||||
import chronicles
|
||||
import sequtils
|
||||
import sets
|
||||
import ../../peerid
|
||||
import ./consts
|
||||
import ./xordistance
|
||||
@@ -9,22 +10,95 @@ import ./lookupstate
|
||||
import ./requests
|
||||
import ./keys
|
||||
import ../protocol
|
||||
import ../../switch
|
||||
import ./protobuf
|
||||
import ../../switch
|
||||
import ../../multihash
|
||||
import ../../utils/heartbeat
|
||||
import std/[options, tables]
|
||||
import std/[times, options, tables]
|
||||
import results
|
||||
|
||||
logScope:
|
||||
topics = "kad-dht"
|
||||
|
||||
type EntryKey* = object
|
||||
data: seq[byte]
|
||||
|
||||
proc init*(T: typedesc[EntryKey], inner: seq[byte]): EntryKey {.gcsafe, raises: [].} =
|
||||
EntryKey(data: inner)
|
||||
|
||||
type EntryValue* = object
|
||||
data*: seq[byte] # public because needed for tests
|
||||
|
||||
proc init*(
|
||||
T: typedesc[EntryValue], inner: seq[byte]
|
||||
): EntryValue {.gcsafe, raises: [].} =
|
||||
EntryValue(data: inner)
|
||||
|
||||
type TimeStamp* = object
|
||||
# Currently a string, because for some reason, that's what is chosen at the protobuf level
|
||||
# TODO: convert between RFC3339 strings and use of integers (i.e. the _correct_ way)
|
||||
ts*: string # only public because needed for tests
|
||||
|
||||
type EntryRecord* = object
|
||||
value*: EntryValue # only public because needed for tests
|
||||
time*: TimeStamp # only public because needed for tests
|
||||
|
||||
proc init*(
|
||||
T: typedesc[EntryRecord], value: EntryValue, time: Option[TimeStamp]
|
||||
): EntryRecord {.gcsafe, raises: [].} =
|
||||
EntryRecord(value: value, time: time.get(TimeStamp(ts: $times.now().utc)))
|
||||
|
||||
type LocalTable* = object
|
||||
entries*: Table[EntryKey, EntryRecord] # public because needed for tests
|
||||
|
||||
proc init(self: typedesc[LocalTable]): LocalTable {.raises: [].} =
|
||||
LocalTable()
|
||||
|
||||
type EntryCandidate* = object
|
||||
key*: EntryKey
|
||||
value*: EntryValue
|
||||
|
||||
type ValidatedEntry* = object
|
||||
key: EntryKey
|
||||
value: EntryValue
|
||||
|
||||
proc init*(
|
||||
T: typedesc[ValidatedEntry], key: EntryKey, value: EntryValue
|
||||
): ValidatedEntry {.gcsafe, raises: [].} =
|
||||
ValidatedEntry(key: key, value: value)
|
||||
|
||||
type EntryValidator* = ref object of RootObj
|
||||
method isValid*(
|
||||
self: EntryValidator, key: EntryKey, val: EntryValue
|
||||
): bool {.base, raises: [], gcsafe.} =
|
||||
doAssert(false, "unimplimented base method")
|
||||
|
||||
type EntrySelector* = ref object of RootObj
|
||||
method select*(
|
||||
self: EntrySelector, cand: EntryRecord, others: seq[EntryRecord]
|
||||
): Result[EntryRecord, string] {.base, raises: [], gcsafe.} =
|
||||
doAssert(false, "EntrySelection base not implemented")
|
||||
|
||||
type KadDHT* = ref object of LPProtocol
|
||||
switch: Switch
|
||||
rng: ref HmacDrbgContext
|
||||
rtable*: RoutingTable
|
||||
maintenanceLoop: Future[void]
|
||||
dataTable*: LocalTable
|
||||
entryValidator: EntryValidator
|
||||
entrySelector: EntrySelector
|
||||
|
||||
proc insert*(
|
||||
self: var LocalTable, value: sink ValidatedEntry, time: TimeStamp
|
||||
) {.raises: [].} =
|
||||
debug "local table insertion", key = value.key.data, value = value.value.data
|
||||
self.entries[value.key] = EntryRecord(value: value.value, time: time)
|
||||
|
||||
const MaxMsgSize = 4096
|
||||
# Forward declaration
|
||||
proc findNode*(
|
||||
kad: KadDHT, targetId: Key
|
||||
): Future[seq[PeerId]] {.async: (raises: [CancelledError]).}
|
||||
|
||||
proc sendFindNode(
|
||||
kad: KadDHT, peerId: PeerId, addrs: seq[MultiAddress], targetId: Key
|
||||
@@ -36,16 +110,13 @@ proc sendFindNode(
|
||||
await kad.switch.dial(peerId, KadCodec)
|
||||
else:
|
||||
await kad.switch.dial(peerId, addrs, KadCodec)
|
||||
|
||||
defer:
|
||||
await conn.close()
|
||||
|
||||
let msg = Message(msgType: MessageType.findNode, key: some(targetId.getBytes()))
|
||||
|
||||
await conn.writeLp(msg.encode().buffer)
|
||||
|
||||
let reply = Message.decode(await conn.readLp(MaxMsgSize)).tryGet()
|
||||
|
||||
if reply.msgType != MessageType.findNode:
|
||||
raise newException(ValueError, "unexpected message type in reply: " & $reply)
|
||||
|
||||
@@ -68,30 +139,92 @@ proc waitRepliesOrTimeouts(
|
||||
|
||||
return (receivedReplies, failedPeers)
|
||||
|
||||
proc dispatchPutVal(
|
||||
kad: KadDHT, peer: PeerId, entry: ValidatedEntry
|
||||
): Future[void] {.async: (raises: [CancelledError, DialFailedError, LPStreamError]).} =
|
||||
let conn = await kad.switch.dial(peer, KadCodec)
|
||||
defer:
|
||||
await conn.close()
|
||||
let msg = Message(
|
||||
msgType: MessageType.putValue,
|
||||
record: some(Record(key: some(entry.key.data), value: some(entry.value.data))),
|
||||
)
|
||||
await conn.writeLp(msg.encode().buffer)
|
||||
|
||||
let reply = Message.decode(await conn.readLp(MaxMsgSize)).valueOr:
|
||||
# todo log this more meaningfully
|
||||
error "putValue reply decode fail", error = error, conn = conn
|
||||
return
|
||||
if reply != msg:
|
||||
error "unexpected change between msg and reply: ",
|
||||
msg = msg, reply = reply, conn = conn
|
||||
|
||||
proc putValue*(
|
||||
kad: KadDHT, entKey: EntryKey, value: EntryValue, timeout: Option[int]
|
||||
): Future[Result[void, string]] {.async: (raises: [CancelledError]), gcsafe.} =
|
||||
if not kad.entryValidator.isValid(entKey, value):
|
||||
return err("invalid key/value pair")
|
||||
|
||||
let others: seq[EntryRecord] =
|
||||
if entKey in kad.dataTable.entries:
|
||||
@[kad.dataTable.entries.getOrDefault(entKey)]
|
||||
else:
|
||||
@[]
|
||||
|
||||
let candAsRec = EntryRecord.init(value, none(TimeStamp))
|
||||
let confirmedRec = kad.entrySelector.select(candAsRec, others).valueOr:
|
||||
error "application provided selector error (local)", msg = error
|
||||
return err(error)
|
||||
trace "local putval", candidate = candAsRec, others = others, selected = confirmedRec
|
||||
let validEnt = ValidatedEntry.init(entKey, confirmedRec.value)
|
||||
|
||||
let peers = await kad.findNode(entKey.data.toKey())
|
||||
# We first prime the sends so the data is ready to go
|
||||
let rpcBatch = peers.mapIt(kad.dispatchPutVal(it, validEnt))
|
||||
# then we do the `move`, as insert takes the data as `sink`
|
||||
kad.dataTable.insert(validEnt, confirmedRec.time)
|
||||
try:
|
||||
# now that the all the data is where it needs to be in memory, we can dispatch the
|
||||
# RPCs
|
||||
await rpcBatch.allFutures().wait(chronos.seconds(timeout.get(5)))
|
||||
|
||||
# It's quite normal for the dispatch to timeout, as it would require all calls to get
|
||||
# their response. Downstream users may desire some sort of functionality in the
|
||||
# future to get rpc telemetry, but in the meantime, we just move on...
|
||||
except AsyncTimeoutError:
|
||||
discard
|
||||
return results.ok()
|
||||
|
||||
# Helper function forward declaration
|
||||
proc checkConvergence(state: LookupState, me: PeerId): bool {.raises: [], gcsafe.}
|
||||
|
||||
proc findNode*(
|
||||
kad: KadDHT, targetId: Key
|
||||
): Future[seq[PeerId]] {.async: (raises: [CancelledError]).} =
|
||||
## Node lookup. Iteratively search for the k closest peers to a target ID.
|
||||
## Will not necessarily return the target itself
|
||||
|
||||
#debug "findNode", target = target
|
||||
# TODO: should it return a single peer instead? read spec
|
||||
|
||||
var initialPeers = kad.rtable.findClosestPeers(targetId, DefaultReplic)
|
||||
var state = LookupState.init(targetId, initialPeers)
|
||||
var state = LookupState.init(targetId, initialPeers, kad.rtable.hasher)
|
||||
var addrTable: Table[PeerId, seq[MultiAddress]] =
|
||||
initTable[PeerId, seq[MultiAddress]]()
|
||||
|
||||
while not state.done:
|
||||
let toQuery = state.selectAlphaPeers()
|
||||
debug "queries", list = toQuery.mapIt(it.shortLog()), addrTab = addrTable
|
||||
var pendingFutures = initTable[PeerId, Future[Message]]()
|
||||
|
||||
for peer in toQuery:
|
||||
if pendingFutures.hasKey(peer):
|
||||
continue
|
||||
|
||||
# TODO: pending futures always empty here, no?
|
||||
for peer in toQuery.filterIt(
|
||||
kad.switch.peerInfo.peerId != it or pendingFutures.hasKey(it)
|
||||
):
|
||||
state.markPending(peer)
|
||||
|
||||
pendingFutures[peer] = kad
|
||||
.sendFindNode(peer, addrTable.getOrDefault(peer, @[]), targetId)
|
||||
.wait(5.seconds)
|
||||
.wait(chronos.seconds(5))
|
||||
|
||||
state.activeQueries.inc
|
||||
|
||||
@@ -99,23 +232,58 @@ proc findNode*(
|
||||
|
||||
for msg in successfulReplies:
|
||||
for peer in msg.closerPeers:
|
||||
addrTable[PeerId.init(peer.id).get()] = peer.addrs
|
||||
let pid = PeerId.init(peer.id)
|
||||
if not pid.isOk:
|
||||
error "PeerId init went bad. this is unusual", data = peer.id
|
||||
continue
|
||||
addrTable[pid.get()] = peer.addrs
|
||||
state.updateShortlist(
|
||||
msg,
|
||||
proc(p: PeerInfo) =
|
||||
discard kad.rtable.insert(p.peerId)
|
||||
kad.switch.peerStore[AddressBook][p.peerId] = p.addrs
|
||||
# Nodes might return different addresses for a peer, so we append instead of replacing
|
||||
var existingAddresses =
|
||||
kad.switch.peerStore[AddressBook][p.peerId].toHashSet()
|
||||
for a in p.addrs:
|
||||
existingAddresses.incl(a)
|
||||
kad.switch.peerStore[AddressBook][p.peerId] = existingAddresses.toSeq()
|
||||
# TODO: add TTL to peerstore, otherwise we can spam it with junk
|
||||
,
|
||||
kad.rtable.hasher,
|
||||
)
|
||||
|
||||
for timedOut in timedOutPeers:
|
||||
state.markFailed(timedOut)
|
||||
|
||||
state.done = state.checkConvergence()
# Check for convergence: no active queries, and no other peers to be selected
state.done = checkConvergence(state, kad.switch.peerInfo.peerId)
|
||||
|
||||
return state.selectClosestK()
|
||||
|
||||
proc findPeer*(
|
||||
kad: KadDHT, peer: PeerId
|
||||
): Future[Result[PeerInfo, string]] {.async: (raises: [CancelledError]).} =
|
||||
## Walks the key space until it finds candidate addresses for a peer Id
|
||||
|
||||
if kad.switch.peerInfo.peerId == peer:
|
||||
# Looking for yourself.
|
||||
return ok(kad.switch.peerInfo)
|
||||
|
||||
if kad.switch.isConnected(peer):
|
||||
# Return known info about already connected peer
|
||||
return ok(PeerInfo(peerId: peer, addrs: kad.switch.peerStore[AddressBook][peer]))
|
||||
|
||||
let foundNodes = await kad.findNode(peer.toKey())
|
||||
if not foundNodes.contains(peer):
|
||||
return err("peer not found")
|
||||
|
||||
return ok(PeerInfo(peerId: peer, addrs: kad.switch.peerStore[AddressBook][peer]))
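
A short hedged sketch of calling findPeer; the names are placeholders and error handling is minimal.

proc locateExample(
    kad: KadDHT, target: PeerId
) {.async: (raises: [CancelledError]).} =
  let found = await kad.findPeer(target)
  if found.isOk:
    echo "found ", found.get().peerId, " at ", found.get().addrs
  else:
    echo "lookup failed: ", found.error
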
|
||||
|
||||
proc checkConvergence(state: LookupState, me: PeerId): bool {.raises: [], gcsafe.} =
|
||||
let ready = state.activeQueries == 0
|
||||
let noNew = selectAlphaPeers(state).filterIt(me != it).len == 0
|
||||
return ready and noNew
|
||||
|
||||
proc bootstrap*(
|
||||
kad: KadDHT, bootstrapNodes: seq[PeerInfo]
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
@@ -123,28 +291,38 @@ proc bootstrap*(
|
||||
try:
|
||||
await kad.switch.connect(b.peerId, b.addrs)
|
||||
debug "connected to bootstrap peer", peerId = b.peerId
|
||||
except CatchableError as e:
|
||||
error "failed to connect to bootstrap peer", peerId = b.peerId, error = e.msg
|
||||
except DialFailedError as e:
|
||||
# at some point we will want to bubble up a Result[void, SomeErrorEnum]
|
||||
error "failed to dial to bootstrap peer", peerId = b.peerId, error = e.msg
|
||||
continue
|
||||
|
||||
try:
|
||||
let msg =
|
||||
await kad.sendFindNode(b.peerId, b.addrs, kad.rtable.selfId).wait(5.seconds)
|
||||
for peer in msg.closerPeers:
|
||||
let p = PeerId.init(peer.id).tryGet()
|
||||
discard kad.rtable.insert(p)
|
||||
let msg =
|
||||
try:
|
||||
await kad.sendFindNode(b.peerId, b.addrs, kad.rtable.selfId).wait(
|
||||
chronos.seconds(5)
|
||||
)
|
||||
except CatchableError as e:
|
||||
debug "send find node exception during bootstrap",
|
||||
target = b.peerId, addrs = b.addrs, err = e.msg
|
||||
continue
|
||||
for peer in msg.closerPeers:
|
||||
let p = PeerId.init(peer.id).valueOr:
|
||||
debug "invalid peer id received", error = error
|
||||
continue
|
||||
discard kad.rtable.insert(p)
|
||||
try:
|
||||
kad.switch.peerStore[AddressBook][p] = peer.addrs
|
||||
except:
|
||||
error "this is here because an ergonomic means of keying into a table without exceptions is unknown"
|
||||
|
||||
# bootstrap node replied successfully. Adding to routing table
|
||||
discard kad.rtable.insert(b.peerId)
|
||||
except CatchableError as e:
|
||||
error "bootstrap failed for peer", peerId = b.peerId, exc = e.msg
|
||||
# bootstrap node replied successfully. Adding to routing table
|
||||
discard kad.rtable.insert(b.peerId)
|
||||
|
||||
try:
|
||||
# Adding some random node to prepopulate the table
|
||||
discard await kad.findNode(PeerId.random(kad.rng).tryGet().toKey())
|
||||
info "bootstrap lookup complete"
|
||||
except CatchableError as e:
|
||||
error "bootstrap lookup failed", error = e.msg
|
||||
let key = PeerId.random(kad.rng).valueOr:
|
||||
doAssert(false, "this should never happen")
|
||||
return
|
||||
discard await kad.findNode(key.toKey())
|
||||
info "bootstrap lookup complete"
|
||||
|
||||
proc refreshBuckets(kad: KadDHT) {.async: (raises: [CancelledError]).} =
|
||||
for i in 0 ..< kad.rtable.buckets.len:
|
||||
@@ -153,49 +331,121 @@ proc refreshBuckets(kad: KadDHT) {.async: (raises: [CancelledError]).} =
|
||||
discard await kad.findNode(randomKey)
|
||||
|
||||
proc maintainBuckets(kad: KadDHT) {.async: (raises: [CancelledError]).} =
|
||||
heartbeat "refresh buckets", 10.minutes:
|
||||
heartbeat "refresh buckets", chronos.minutes(10):
|
||||
await kad.refreshBuckets()
|
||||
|
||||
proc new*(
|
||||
T: typedesc[KadDHT], switch: Switch, rng: ref HmacDrbgContext = newRng()
|
||||
T: typedesc[KadDHT],
|
||||
switch: Switch,
|
||||
validator: EntryValidator,
|
||||
entrySelector: EntrySelector,
|
||||
rng: ref HmacDrbgContext = newRng(),
|
||||
): T {.raises: [].} =
|
||||
var rtable = RoutingTable.init(switch.peerInfo.peerId.toKey())
|
||||
let kad = T(rng: rng, switch: switch, rtable: rtable)
|
||||
var rtable = RoutingTable.init(switch.peerInfo.peerId.toKey(), Opt.none(XorDHasher))
|
||||
let kad = T(
|
||||
rng: rng,
|
||||
switch: switch,
|
||||
rtable: rtable,
|
||||
dataTable: LocalTable.init(),
|
||||
entryValidator: validator,
|
||||
entrySelector: entrySelector,
|
||||
)
|
||||
|
||||
kad.codec = KadCodec
|
||||
kad.handler = proc(
|
||||
conn: Connection, proto: string
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
try:
|
||||
while not conn.atEof:
|
||||
let
|
||||
buf = await conn.readLp(MaxMsgSize)
|
||||
msg = Message.decode(buf).tryGet()
|
||||
|
||||
case msg.msgType
|
||||
of MessageType.findNode:
|
||||
let targetIdBytes = msg.key.get()
|
||||
let targetId = PeerId.init(targetIdBytes).tryGet()
|
||||
let closerPeers = kad.rtable.findClosest(targetId.toKey(), DefaultReplic)
|
||||
let responsePb = encodeFindNodeReply(closerPeers, switch)
|
||||
await conn.writeLp(responsePb.buffer)
|
||||
|
||||
# Peer is useful. adding to rtable
|
||||
discard kad.rtable.insert(conn.peerId)
|
||||
else:
|
||||
raise newException(LPError, "unhandled kad-dht message type")
|
||||
except CancelledError as exc:
|
||||
raise exc
|
||||
except CatchableError:
|
||||
discard
|
||||
# TODO: figure out why this fails:
|
||||
# error "could not handle request",
|
||||
# peerId = conn.PeerId, err = getCurrentExceptionMsg()
|
||||
finally:
|
||||
defer:
|
||||
await conn.close()
|
||||
while not conn.atEof:
|
||||
let buf =
|
||||
try:
|
||||
await conn.readLp(MaxMsgSize)
|
||||
except LPStreamError as e:
|
||||
debug "Read error when handling kademlia RPC", conn = conn, err = e.msg
|
||||
return
|
||||
let msg = Message.decode(buf).valueOr:
|
||||
debug "msg decode error handling kademlia RPC", err = error
|
||||
return
|
||||
|
||||
case msg.msgType
|
||||
of MessageType.findNode:
|
||||
let targetIdBytes = msg.key.valueOr:
|
||||
error "findNode message without key data present", msg = msg, conn = conn
|
||||
return
|
||||
let targetId = PeerId.init(targetIdBytes).valueOr:
|
||||
error "findNode message without valid key data", msg = msg, conn = conn
|
||||
return
|
||||
let closerPeers = kad.rtable
|
||||
.findClosest(targetId.toKey(), DefaultReplic)
|
||||
# exclude the requesting node, since telling a peer about itself does not reduce the distance
|
||||
.filterIt(it != conn.peerId.toKey())
|
||||
|
||||
let responsePb = encodeFindNodeReply(closerPeers, switch)
|
||||
try:
|
||||
await conn.writeLp(responsePb.buffer)
|
||||
except LPStreamError as e:
|
||||
debug "write error when writing kad find-node RPC reply",
|
||||
conn = conn, err = e.msg
|
||||
return
|
||||
|
||||
# Peer is useful. adding to rtable
|
||||
discard kad.rtable.insert(conn.peerId)
|
||||
of MessageType.putValue:
|
||||
let record = msg.record.valueOr:
|
||||
error "no record in message buffer", msg = msg, conn = conn
|
||||
return
|
||||
let (skey, svalue) =
|
||||
if record.key.isSome() and record.value.isSome():
|
||||
(record.key.unsafeGet(), record.value.unsafeGet())
|
||||
else:
|
||||
error "no key or no value in rpc buffer", msg = msg, conn = conn
|
||||
return
|
||||
let key = EntryKey.init(skey)
|
||||
let value = EntryValue.init(svalue)
|
||||
|
||||
# Value sanitisation done. Start insertion process
|
||||
if not kad.entryValidator.isValid(key, value):
|
||||
return
|
||||
|
||||
let others =
|
||||
if kad.dataTable.entries.contains(key):
|
||||
# workaround needed to key into the table without raising exceptions.
|
||||
@[kad.dataTable.entries.getOrDefault(key)]
|
||||
else:
|
||||
@[]
|
||||
let candRec = EntryRecord.init(value, none(TimeStamp))
|
||||
let selectedRec = kad.entrySelector.select(candRec, others).valueOr:
|
||||
error "application provided selector error", msg = error, conn = conn
|
||||
return
|
||||
trace "putval handler selection",
|
||||
cand = candRec, others = others, selected = selectedRec
|
||||
|
||||
# Assume that if selection goes with another value, that it is valid
|
||||
let validated = ValidatedEntry(key: key, value: selectedRec.value)
|
||||
|
||||
kad.dataTable.insert(validated, selectedRec.time)
|
||||
# consistent with following link, echo message without change
|
||||
# https://github.com/libp2p/js-libp2p/blob/cf9aab5c841ec08bc023b9f49083c95ad78a7a07/packages/kad-dht/src/rpc/handlers/put-value.ts#L22
|
||||
try:
|
||||
await conn.writeLp(buf)
|
||||
except LPStreamError as e:
|
||||
debug "write error when writing kad find-node RPC reply",
|
||||
conn = conn, err = e.msg
|
||||
return
|
||||
else:
|
||||
error "unhandled kad-dht message type", msg = msg
|
||||
return
|
||||
return kad
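
A hedged construction sketch for the new signature. AllowAllValidator and LastWriteWinsSelector are hypothetical application-side implementations of EntryValidator/EntrySelector, bootstrapPeerInfo is assumed to be known, and the builder chain follows the usual nim-libp2p SwitchBuilder API.

proc setupExample() {.async.} =
  let rng = newRng()
  let switch = SwitchBuilder
    .new()
    .withRng(rng)
    .withAddresses(@[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()])
    .withTcpTransport()
    .withMplex()
    .withNoise()
    .build()
  # AllowAllValidator / LastWriteWinsSelector: hypothetical application types
  let kad = KadDHT.new(switch, AllowAllValidator(), LastWriteWinsSelector(), rng)
  switch.mount(kad)
  await switch.start()
  await kad.start()
  await kad.bootstrap(@[bootstrapPeerInfo]) # bootstrapPeerInfo: PeerInfo of a known node (assumed)
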
|
||||
|
||||
proc setSelector*(kad: KadDHT, selector: EntrySelector) =
|
||||
doAssert(selector != nil)
|
||||
kad.entrySelector = selector
|
||||
|
||||
proc setValidator*(kad: KadDHT, validator: EntryValidator) =
|
||||
doAssert(validator != nil)
|
||||
kad.entryValidator = validator
|
||||
|
||||
method start*(
|
||||
kad: KadDHT
|
||||
): Future[void] {.async: (raises: [CancelledError], raw: true).} =
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import nimcrypto/sha2
|
||||
import ../../peerid
|
||||
import ./consts
|
||||
import chronicles
|
||||
import stew/byteutils
|
||||
|
||||
type
|
||||
KeyType* {.pure.} = enum
|
||||
Unhashed
|
||||
Raw
|
||||
PeerId
|
||||
|
||||
@@ -13,15 +12,11 @@ type
|
||||
case kind*: KeyType
|
||||
of KeyType.PeerId:
|
||||
peerId*: PeerId
|
||||
of KeyType.Raw, KeyType.Unhashed:
|
||||
data*: array[IdLength, byte]
|
||||
of KeyType.Raw:
|
||||
data*: seq[byte]
|
||||
|
||||
proc toKey*(s: seq[byte]): Key =
|
||||
doAssert s.len == IdLength
|
||||
var data: array[IdLength, byte]
|
||||
for i in 0 ..< IdLength:
|
||||
data[i] = s[i]
|
||||
return Key(kind: KeyType.Raw, data: data)
|
||||
return Key(kind: KeyType.Raw, data: s)
|
||||
|
||||
proc toKey*(p: PeerId): Key =
|
||||
return Key(kind: KeyType.PeerId, peerId: p)
|
||||
@@ -36,7 +31,7 @@ proc getBytes*(k: Key): seq[byte] =
|
||||
case k.kind
|
||||
of KeyType.PeerId:
|
||||
k.peerId.getBytes()
|
||||
of KeyType.Raw, KeyType.Unhashed:
|
||||
of KeyType.Raw:
|
||||
@(k.data)
|
||||
|
||||
template `==`*(a, b: Key): bool =
|
||||
@@ -46,7 +41,7 @@ proc shortLog*(k: Key): string =
|
||||
case k.kind
|
||||
of KeyType.PeerId:
|
||||
"PeerId:" & $k.peerId
|
||||
of KeyType.Raw, KeyType.Unhashed:
|
||||
of KeyType.Raw:
|
||||
$k.kind & ":" & toHex(k.data)
|
||||
|
||||
chronicles.formatIt(Key):
|
||||
|
||||
@@ -27,7 +27,10 @@ proc alreadyInShortlist(state: LookupState, peer: Peer): bool =
|
||||
return state.shortlist.anyIt(it.peerId.getBytes() == peer.id)
|
||||
|
||||
proc updateShortlist*(
|
||||
state: var LookupState, msg: Message, onInsert: proc(p: PeerInfo) {.gcsafe.}
|
||||
state: var LookupState,
|
||||
msg: Message,
|
||||
onInsert: proc(p: PeerInfo) {.gcsafe.},
|
||||
hasher: Opt[XorDHasher],
|
||||
) =
|
||||
for newPeer in msg.closerPeers.filterIt(not alreadyInShortlist(state, it)):
|
||||
let peerInfo = PeerInfo(peerId: PeerId.init(newPeer.id).get(), addrs: newPeer.addrs)
|
||||
@@ -36,7 +39,7 @@ proc updateShortlist*(
|
||||
state.shortlist.add(
|
||||
LookupNode(
|
||||
peerId: peerInfo.peerId,
|
||||
distance: xorDistance(peerInfo.peerId, state.targetId),
|
||||
distance: xorDistance(peerInfo.peerId, state.targetId, hasher),
|
||||
queried: false,
|
||||
pending: false,
|
||||
failed: false,
|
||||
@@ -77,7 +80,12 @@ proc selectAlphaPeers*(state: LookupState): seq[PeerId] =
|
||||
break
|
||||
return selected
|
||||
|
||||
proc init*(T: type LookupState, targetId: Key, initialPeers: seq[PeerId]): T =
|
||||
proc init*(
|
||||
T: type LookupState,
|
||||
targetId: Key,
|
||||
initialPeers: seq[PeerId],
|
||||
hasher: Opt[XorDHasher],
|
||||
): T =
|
||||
var res = LookupState(
|
||||
targetId: targetId,
|
||||
shortlist: @[],
|
||||
@@ -90,7 +98,7 @@ proc init*(T: type LookupState, targetId: Key, initialPeers: seq[PeerId]): T =
|
||||
res.shortlist.add(
|
||||
LookupNode(
|
||||
peerId: p,
|
||||
distance: xorDistance(p, targetId),
|
||||
distance: xorDistance(p, targetId, hasher),
|
||||
queried: false,
|
||||
pending: false,
|
||||
failed: false,
|
||||
@@ -103,11 +111,6 @@ proc init*(T: type LookupState, targetId: Key, initialPeers: seq[PeerId]): T =
|
||||
)
|
||||
return res
|
||||
|
||||
proc checkConvergence*(state: LookupState): bool =
|
||||
let ready = state.activeQueries == 0
|
||||
let noNew = selectAlphaPeers(state).len == 0
|
||||
return ready and noNew
|
||||
|
||||
proc selectClosestK*(state: LookupState): seq[PeerId] =
|
||||
var res: seq[PeerId] = @[]
|
||||
for p in state.shortlist.filterIt(not it.failed):
|
||||
|
||||
@@ -8,6 +8,7 @@ import ./xordistance
|
||||
import ../../peerid
|
||||
import sequtils
|
||||
import ../../utils/sequninit
|
||||
import results
|
||||
|
||||
logScope:
|
||||
topics = "kad-dht rtable"
|
||||
@@ -23,15 +24,16 @@ type
|
||||
RoutingTable* = ref object
|
||||
selfId*: Key
|
||||
buckets*: seq[Bucket]
|
||||
hasher*: Opt[XorDHasher]
|
||||
|
||||
proc `$`*(rt: RoutingTable): string =
|
||||
"selfId(" & $rt.selfId & ") buckets(" & $rt.buckets & ")"
|
||||
|
||||
proc init*(T: typedesc[RoutingTable], selfId: Key): T =
|
||||
return RoutingTable(selfId: selfId, buckets: @[])
|
||||
proc init*(T: typedesc[RoutingTable], selfId: Key, hasher: Opt[XorDHasher]): T =
|
||||
return RoutingTable(selfId: selfId, buckets: @[], hasher: hasher)
|
||||
|
||||
proc bucketIndex*(selfId, key: Key): int =
|
||||
return xorDistance(selfId, key).leadingZeros
|
||||
proc bucketIndex*(selfId, key: Key, hasher: Opt[XorDHasher]): int =
|
||||
return xorDistance(selfId, key, hasher).leadingZeros
|
||||
|
||||
proc peerIndexInBucket(bucket: var Bucket, nodeId: Key): Opt[int] =
|
||||
for i, p in bucket.peers:
|
||||
@@ -43,7 +45,7 @@ proc insert*(rtable: var RoutingTable, nodeId: Key): bool =
|
||||
if nodeId == rtable.selfId:
|
||||
return false # No self insertion
|
||||
|
||||
let idx = bucketIndex(rtable.selfId, nodeId)
|
||||
let idx = bucketIndex(rtable.selfId, nodeId, rtable.hasher)
|
||||
if idx >= maxBuckets:
|
||||
trace "cannot insert node. max buckets have been reached",
|
||||
nodeId, bucketIdx = idx, maxBuckets
|
||||
@@ -80,7 +82,9 @@ proc findClosest*(rtable: RoutingTable, targetId: Key, count: int): seq[Key] =
|
||||
|
||||
allNodes.sort(
|
||||
proc(a, b: Key): int =
|
||||
cmp(xorDistance(a, targetId), xorDistance(b, targetId))
|
||||
cmp(
|
||||
xorDistance(a, targetId, rtable.hasher), xorDistance(b, targetId, rtable.hasher)
|
||||
)
|
||||
)
|
||||
|
||||
return allNodes[0 ..< min(count, allNodes.len)]
|
||||
|
||||
@@ -1,9 +1,27 @@
|
||||
import ./consts
|
||||
import stew/arrayOps
|
||||
import ./keys
|
||||
import nimcrypto/sha2
|
||||
import ../../peerid
|
||||
import results
|
||||
|
||||
type XorDistance* = array[IdLength, byte]
|
||||
type XorDHasher* = proc(input: seq[byte]): array[IdLength, byte] {.
|
||||
raises: [], nimcall, noSideEffect, gcsafe
|
||||
.}
|
||||
|
||||
proc defaultHasher(
|
||||
input: seq[byte]
|
||||
): array[IdLength, byte] {.raises: [], nimcall, noSideEffect, gcsafe.} =
|
||||
return sha256.digest(input).data
|
||||
|
||||
# useful for testing purposes
|
||||
proc noOpHasher*(
|
||||
input: seq[byte]
|
||||
): array[IdLength, byte] {.raises: [], nimcall, noSideEffect, gcsafe.} =
|
||||
var data: array[IdLength, byte]
|
||||
discard data.copyFrom(input)
|
||||
return data
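
A hedged test sketch showing where the pluggable hasher goes: with noOpHasher the distance is taken over the raw key bytes, which makes bucket placement easy to predict in unit tests. Assumes the routingtable module is imported and that selfPeerId/otherPeerId are test PeerIds.

# noOpHasher keeps the raw key bytes, so distances are deterministic in tests
let hasher = Opt.some(XorDHasher(noOpHasher))
var rtable = RoutingTable.init(selfPeerId.toKey(), hasher) # selfPeerId: PeerId (assumed)
discard rtable.insert(otherPeerId.toKey())                 # otherPeerId: PeerId (assumed)
echo xorDistance(selfPeerId.toKey(), otherPeerId.toKey(), hasher).leadingZeros
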
|
||||
|
||||
proc countLeadingZeroBits*(b: byte): int =
|
||||
for i in 0 .. 7:
|
||||
@@ -31,25 +49,23 @@ proc `<`*(a, b: XorDistance): bool =
|
||||
proc `<=`*(a, b: XorDistance): bool =
|
||||
cmp(a, b) <= 0
|
||||
|
||||
proc hashFor(k: Key): seq[byte] =
|
||||
proc hashFor(k: Key, hasher: Opt[XorDHasher]): seq[byte] =
|
||||
return
|
||||
@(
|
||||
case k.kind
|
||||
of KeyType.PeerId:
|
||||
sha256.digest(k.peerId.getBytes()).data
|
||||
hasher.get(defaultHasher)(k.peerId.getBytes())
|
||||
of KeyType.Raw:
|
||||
sha256.digest(k.data).data
|
||||
of KeyType.Unhashed:
|
||||
k.data
|
||||
hasher.get(defaultHasher)(k.data)
|
||||
)
|
||||
|
||||
proc xorDistance*(a, b: Key): XorDistance =
|
||||
let hashA = a.hashFor()
|
||||
let hashB = b.hashFor()
|
||||
proc xorDistance*(a, b: Key, hasher: Opt[XorDHasher]): XorDistance =
|
||||
let hashA = a.hashFor(hasher)
|
||||
let hashB = b.hashFor(hasher)
|
||||
var response: XorDistance
|
||||
for i in 0 ..< hashA.len:
|
||||
response[i] = hashA[i] xor hashB[i]
|
||||
return response
|
||||
|
||||
proc xorDistance*(a: PeerId, b: Key): XorDistance =
|
||||
xorDistance(a.toKey(), b)
|
||||
proc xorDistance*(a: PeerId, b: Key, hasher: Opt[XorDHasher]): XorDistance =
|
||||
xorDistance(a.toKey(), b, hasher)
|
||||
|
||||
@@ -12,8 +12,6 @@
|
||||
import chronos, chronicles, sequtils
|
||||
import stew/endians2
|
||||
import ./core, ../../stream/connection
|
||||
when defined(libp2p_quic_support):
|
||||
import ../../transports/quictransport
|
||||
|
||||
logScope:
|
||||
topics = "libp2p perf"
|
||||
@@ -59,13 +57,8 @@ proc perf*(
|
||||
statsCopy.uploadBytes += toWrite.uint
|
||||
p.stats = statsCopy
|
||||
|
||||
# Close connection after writing for TCP, but not for QUIC
|
||||
when defined(libp2p_quic_support):
|
||||
if not (conn of QuicStream):
|
||||
await conn.close()
|
||||
# For QUIC streams, don't close yet - let server manage lifecycle
|
||||
else:
|
||||
await conn.close()
|
||||
# Close write side of the stream (half-close) to signal EOF to server
|
||||
await conn.closeWrite()
|
||||
|
||||
size = sizeToRead
|
||||
|
||||
@@ -80,10 +73,8 @@ proc perf*(
|
||||
statsCopy.downloadBytes += toRead.uint
|
||||
p.stats = statsCopy
|
||||
|
||||
# Close QUIC connections after read phase
|
||||
when defined(libp2p_quic_support):
|
||||
if conn of QuicStream:
|
||||
await conn.close()
|
||||
# Close the connection after reading
|
||||
await conn.close()
|
||||
except CancelledError as e:
|
||||
raise e
|
||||
except LPStreamError as e:
|
||||
|
||||
@@ -14,8 +14,6 @@
|
||||
import chronos, chronicles
|
||||
import stew/endians2
|
||||
import ./core, ../protocol, ../../stream/connection, ../../utility
|
||||
when defined(libp2p_quic_support):
|
||||
import ../../transports/quictransport
|
||||
|
||||
export chronicles, connection
|
||||
|
||||
@@ -26,50 +24,29 @@ type Perf* = ref object of LPProtocol
|
||||
|
||||
proc new*(T: typedesc[Perf]): T {.public.} =
|
||||
var p = T()
|
||||
|
||||
proc handle(conn: Connection, proto: string) {.async: (raises: [CancelledError]).} =
|
||||
var bytesRead = 0
|
||||
try:
|
||||
trace "Received benchmark performance check", conn
|
||||
var
|
||||
sizeBuffer: array[8, byte]
|
||||
size: uint64
|
||||
await conn.readExactly(addr sizeBuffer[0], 8)
|
||||
size = uint64.fromBytesBE(sizeBuffer)
|
||||
|
||||
var toReadBuffer: array[PerfSize, byte]
|
||||
try:
|
||||
# Different handling for QUIC vs TCP streams
|
||||
when defined(libp2p_quic_support):
|
||||
if conn of QuicStream:
|
||||
# QUIC needs timeout-based approach to detect end of upload
|
||||
while not conn.atEof:
|
||||
let readFut = conn.readOnce(addr toReadBuffer[0], PerfSize)
|
||||
let read = readFut.read()
|
||||
if read == 0:
|
||||
break
|
||||
bytesRead += read
|
||||
else:
|
||||
# TCP streams handle EOF properly
|
||||
while true:
|
||||
let read = await conn.readOnce(addr toReadBuffer[0], PerfSize)
|
||||
if read == 0:
|
||||
break
|
||||
bytesRead += read
|
||||
else:
|
||||
# TCP streams handle EOF properly
|
||||
while true:
|
||||
let read = await conn.readOnce(addr toReadBuffer[0], PerfSize)
|
||||
if read == 0:
|
||||
break
|
||||
bytesRead += read
|
||||
except CatchableError:
|
||||
discard
|
||||
var uploadSizeBuffer: array[8, byte]
|
||||
await conn.readExactly(addr uploadSizeBuffer[0], 8)
|
||||
var uploadSize = uint64.fromBytesBE(uploadSizeBuffer)
|
||||
|
||||
var buf: array[PerfSize, byte]
|
||||
while size > 0:
|
||||
let toWrite = min(size, PerfSize)
|
||||
await conn.write(buf[0 ..< toWrite])
|
||||
size -= toWrite
|
||||
var readBuffer: array[PerfSize, byte]
|
||||
while not conn.atEof:
|
||||
try:
|
||||
let readBytes = await conn.readOnce(addr readBuffer[0], PerfSize)
|
||||
if readBytes == 0:
|
||||
break
|
||||
except LPStreamEOFError:
|
||||
break
|
||||
|
||||
var writeBuffer: array[PerfSize, byte]
|
||||
while uploadSize > 0:
|
||||
let toWrite = min(uploadSize, PerfSize)
|
||||
await conn.write(writeBuffer[0 ..< toWrite])
|
||||
uploadSize -= toWrite
|
||||
except CancelledError as exc:
|
||||
trace "cancelled perf handler"
|
||||
raise exc
|
||||
|
||||
@@ -725,9 +725,8 @@ method rpcHandler*(
|
||||
continue
|
||||
|
||||
if (msg.signature.len > 0 or g.verifySignature) and not msg.verify():
|
||||
# always validate if signature is present or required
|
||||
debug "Dropping message due to failed signature verification",
|
||||
msgId = shortLog(msgId), peer
|
||||
debug "Dropping message due to failed signature verification", msg = msg
|
||||
|
||||
await g.punishInvalidMessage(peer, msg)
|
||||
continue
|
||||
|
||||
|
||||
@@ -478,19 +478,16 @@ iterator splitRPCMsg(
|
||||
## exceeds the `maxSize` when trying to fit into an empty `RPCMsg`, the latter is skipped as too large to send.
|
||||
## Every constructed `RPCMsg` is then encoded, optionally anonymized, and yielded as a sequence of bytes.
|
||||
|
||||
var currentRPCMsg = rpcMsg
|
||||
currentRPCMsg.messages = newSeq[Message]()
|
||||
|
||||
var currentSize = byteSize(currentRPCMsg)
|
||||
var currentRPCMsg = RPCMsg()
|
||||
var currentSize = 0
|
||||
|
||||
for msg in rpcMsg.messages:
|
||||
let msgSize = byteSize(msg)
|
||||
|
||||
# Check if adding the next message will exceed maxSize
|
||||
if float(currentSize + msgSize) * 1.1 > float(maxSize):
|
||||
# Guessing 10% protobuf overhead
|
||||
if currentRPCMsg.messages.len == 0:
|
||||
trace "message too big to sent", peer, rpcMsg = shortLog(currentRPCMsg)
|
||||
if currentSize + msgSize > maxSize:
|
||||
if msgSize > maxSize:
|
||||
warn "message too big to sent", peer, rpcMsg = shortLog(msg)
|
||||
continue # Skip this message
|
||||
|
||||
trace "sending msg to peer", peer, rpcMsg = shortLog(currentRPCMsg)
|
||||
@@ -502,11 +499,9 @@ iterator splitRPCMsg(
|
||||
currentSize += msgSize
|
||||
|
||||
# Check if there is a non-empty currentRPCMsg left to be added
|
||||
if currentSize > 0 and currentRPCMsg.messages.len > 0:
|
||||
if currentRPCMsg.messages.len > 0:
|
||||
trace "sending msg to peer", peer, rpcMsg = shortLog(currentRPCMsg)
|
||||
yield encodeRpcMsg(currentRPCMsg, anonymize)
|
||||
else:
|
||||
trace "message too big to sent", peer, rpcMsg = shortLog(currentRPCMsg)
|
||||
|
||||
proc send*(
|
||||
p: PubSubPeer,
|
||||
@@ -542,13 +537,8 @@ proc send*(
|
||||
sendMetrics(msg)
|
||||
encodeRpcMsg(msg, anonymize)
|
||||
|
||||
if encoded.len > p.maxMessageSize and msg.messages.len > 1:
|
||||
for encodedSplitMsg in splitRPCMsg(p, msg, p.maxMessageSize, anonymize):
|
||||
asyncSpawn p.sendEncoded(encodedSplitMsg, isHighPriority, useCustomConn)
|
||||
else:
|
||||
# If the message size is within limits, send it as is
|
||||
trace "sending msg to peer", peer = p, rpcMsg = shortLog(msg)
|
||||
asyncSpawn p.sendEncoded(encoded, isHighPriority, useCustomConn)
|
||||
trace "sending msg to peer", peer = p, rpcMsg = shortLog(msg)
|
||||
asyncSpawn p.sendEncoded(encoded, isHighPriority, useCustomConn)
|
||||
|
||||
proc canAskIWant*(p: PubSubPeer, msgId: MessageId): bool =
|
||||
for sentIHave in p.sentIHaves.mitems():
|
||||
|
||||
@@ -41,15 +41,36 @@ func defaultMsgIdProvider*(m: Message): Result[MessageId, ValidationResult] =
|
||||
proc sign*(msg: Message, privateKey: PrivateKey): CryptoResult[seq[byte]] =
|
||||
ok((?privateKey.sign(PubSubPrefix & encodeMessage(msg, false))).getBytes())
|
||||
|
||||
proc extractPublicKey(m: Message): Opt[PublicKey] =
|
||||
var pubkey: PublicKey
|
||||
if m.fromPeer.hasPublicKey() and m.fromPeer.extractPublicKey(pubkey):
|
||||
Opt.some(pubkey)
|
||||
elif m.key.len > 0 and pubkey.init(m.key):
|
||||
# check if peerId extracted from m.key is the same as m.fromPeer
|
||||
let derivedPeerId = PeerId.init(pubkey).valueOr:
|
||||
warn "could not derive peerId from key field"
|
||||
return Opt.none(PublicKey)
|
||||
|
||||
if derivedPeerId != m.fromPeer:
|
||||
warn "peerId derived from msg.key is not the same as msg.fromPeer",
|
||||
derivedPeerId = derivedPeerId, fromPeer = m.fromPeer
|
||||
return Opt.none(PublicKey)
|
||||
Opt.some(pubkey)
|
||||
else:
|
||||
Opt.none(PublicKey)
|
||||
|
||||
proc verify*(m: Message): bool =
|
||||
if m.signature.len > 0 and m.key.len > 0:
|
||||
if m.signature.len > 0:
|
||||
var msg = m
|
||||
msg.signature = @[]
|
||||
msg.key = @[]
|
||||
|
||||
var remote: Signature
|
||||
var key: PublicKey
|
||||
if remote.init(m.signature) and key.init(m.key):
|
||||
let key = m.extractPublicKey().valueOr:
|
||||
warn "could not extract public key", msg = m
|
||||
return false
|
||||
|
||||
if remote.init(m.signature):
|
||||
trace "verifying signature", remoteSignature = remote
|
||||
result = remote.verify(PubSubPrefix & encodeMessage(msg, false), key)
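
A hedged round-trip sketch for the new verification path: with an Ed25519 keypair the peer id embeds the public key, so msg.key can stay empty and extractPublicKey derives it from fromPeer. Field and helper names follow the usual pubsub message layout and may differ slightly.

import libp2p/crypto/crypto, libp2p/peerid # assumed import paths

let rng = newRng()
let keypair = KeyPair.random(PKScheme.Ed25519, rng[]).expect("keygen")
let author = PeerId.init(keypair.pubkey).expect("peer id")

var msg = Message(fromPeer: author, data: @[1'u8, 2, 3], topic: "demo") # field names assumed
msg.signature = msg.sign(keypair.seckey).expect("sign")
# signature present, key empty: verify() falls back to the key embedded in fromPeer
doAssert msg.verify()
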
|
||||
|
||||
|
||||
@@ -35,12 +35,15 @@ declareGauge(libp2p_rendezvous_namespaces, "number of registered namespaces")
|
||||
|
||||
const
|
||||
RendezVousCodec* = "/rendezvous/1.0.0"
|
||||
# Default minimum TTL per libp2p spec
|
||||
MinimumDuration* = 2.hours
|
||||
# Lower validation limit to accommodate Waku requirements
|
||||
MinimumAcceptedDuration = 1.minutes
|
||||
MaximumDuration = 72.hours
|
||||
MaximumMessageLen = 1 shl 22 # 4MB
|
||||
MinimumNamespaceLen = 1
|
||||
MaximumNamespaceLen = 255
|
||||
RegistrationLimitPerPeer = 1000
|
||||
RegistrationLimitPerPeer* = 1000
|
||||
DiscoverLimit = 1000'u64
|
||||
SemaphoreDefaultSize = 5
|
||||
|
||||
@@ -69,7 +72,7 @@ type
|
||||
Register = object
|
||||
ns: string
|
||||
signedPeerRecord: seq[byte]
|
||||
ttl: Opt[uint64] # in seconds
|
||||
ttl*: Opt[uint64] # in seconds
|
||||
|
||||
RegisterResponse = object
|
||||
status: ResponseStatus
|
||||
@@ -310,27 +313,27 @@ proc decode(_: typedesc[Message], buf: seq[byte]): Opt[Message] =
|
||||
type
|
||||
RendezVousError* = object of DiscoveryError
|
||||
RegisteredData = object
|
||||
expiration: Moment
|
||||
peerId: PeerId
|
||||
data: Register
|
||||
expiration*: Moment
|
||||
peerId*: PeerId
|
||||
data*: Register
|
||||
|
||||
RendezVous* = ref object of LPProtocol
|
||||
# Registered needs to be an offsetted sequence
|
||||
# because we need a stable index for the cookies.
|
||||
registered: OffsettedSeq[RegisteredData]
|
||||
registered*: OffsettedSeq[RegisteredData]
|
||||
# Namespaces is a table whose key is a salted namespace and
|
||||
# the value is the index sequence corresponding to this
|
||||
# namespace in the offsettedqueue.
|
||||
namespaces: Table[string, seq[int]]
|
||||
namespaces*: Table[string, seq[int]]
|
||||
rng: ref HmacDrbgContext
|
||||
salt: string
|
||||
defaultDT: Moment
|
||||
expiredDT: Moment
|
||||
registerDeletionLoop: Future[void]
|
||||
#registerEvent: AsyncEvent # TODO: to raise during the heartbeat
|
||||
# + make the heartbeat sleep duration "smarter"
|
||||
sema: AsyncSemaphore
|
||||
peers: seq[PeerId]
|
||||
cookiesSaved: Table[PeerId, Table[string, seq[byte]]]
|
||||
cookiesSaved*: Table[PeerId, Table[string, seq[byte]]]
|
||||
switch: Switch
|
||||
minDuration: Duration
|
||||
maxDuration: Duration
|
||||
@@ -394,9 +397,8 @@ proc sendDiscoverResponseError(
|
||||
await conn.writeLp(msg.buffer)
|
||||
|
||||
proc countRegister(rdv: RendezVous, peerId: PeerId): int =
|
||||
let n = Moment.now()
|
||||
for data in rdv.registered:
|
||||
if data.peerId == peerId and data.expiration > n:
|
||||
if data.peerId == peerId:
|
||||
result.inc()
|
||||
|
||||
proc save(
|
||||
@@ -409,7 +411,7 @@ proc save(
|
||||
if rdv.registered[index].peerId == peerId:
|
||||
if update == false:
|
||||
return
|
||||
rdv.registered[index].expiration = rdv.defaultDT
|
||||
rdv.registered[index].expiration = rdv.expiredDT
|
||||
rdv.registered.add(
|
||||
RegisteredData(
|
||||
peerId: peerId,
|
||||
@@ -446,7 +448,7 @@ proc unregister(rdv: RendezVous, conn: Connection, u: Unregister) =
|
||||
try:
|
||||
for index in rdv.namespaces[nsSalted]:
|
||||
if rdv.registered[index].peerId == conn.peerId:
|
||||
rdv.registered[index].expiration = rdv.defaultDT
|
||||
rdv.registered[index].expiration = rdv.expiredDT
|
||||
libp2p_rendezvous_registered.dec()
|
||||
except KeyError:
|
||||
return
|
||||
@@ -468,11 +470,17 @@ proc discover(
|
||||
await conn.sendDiscoverResponseError(InvalidCookie)
|
||||
return
|
||||
else:
|
||||
Cookie(offset: rdv.registered.low().uint64 - 1)
|
||||
if d.ns.isSome() and cookie.ns.isSome() and cookie.ns.get() != d.ns.get() or
|
||||
cookie.offset < rdv.registered.low().uint64 or
|
||||
cookie.offset > rdv.registered.high().uint64:
|
||||
cookie = Cookie(offset: rdv.registered.low().uint64 - 1)
|
||||
# Start from the current lowest index (inclusive)
|
||||
Cookie(offset: rdv.registered.low().uint64)
|
||||
if d.ns.isSome() and cookie.ns.isSome() and cookie.ns.get() != d.ns.get():
|
||||
# Namespace changed: start from the beginning of that namespace
|
||||
cookie = Cookie(offset: rdv.registered.low().uint64)
|
||||
elif cookie.offset < rdv.registered.low().uint64:
|
||||
# Cookie behind available range: reset to current low
|
||||
cookie.offset = rdv.registered.low().uint64
|
||||
elif cookie.offset > (rdv.registered.high() + 1).uint64:
|
||||
# Cookie ahead of available range: reset to one past current high (empty page)
|
||||
cookie.offset = (rdv.registered.high() + 1).uint64
|
||||
let namespaces =
|
||||
if d.ns.isSome():
|
||||
try:
|
||||
@@ -485,21 +493,21 @@ proc discover(
|
||||
if namespaces.len() == 0:
|
||||
await conn.sendDiscoverResponse(@[], Cookie())
|
||||
return
|
||||
var offset = namespaces[^1]
|
||||
var nextOffset = cookie.offset
|
||||
let n = Moment.now()
|
||||
var s = collect(newSeq()):
|
||||
for index in namespaces:
|
||||
var reg = rdv.registered[index]
|
||||
if limit == 0:
|
||||
offset = index
|
||||
break
|
||||
if reg.expiration < n or index.uint64 <= cookie.offset:
|
||||
if reg.expiration < n or index.uint64 < cookie.offset:
|
||||
continue
|
||||
limit.dec()
|
||||
nextOffset = index.uint64 + 1
|
||||
reg.data.ttl = Opt.some((reg.expiration - Moment.now()).seconds.uint64)
|
||||
reg.data
|
||||
rdv.rng.shuffle(s)
|
||||
await conn.sendDiscoverResponse(s, Cookie(offset: offset.uint64, ns: d.ns))
|
||||
await conn.sendDiscoverResponse(s, Cookie(offset: nextOffset, ns: d.ns))
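
A small worked illustration of the cookie clamping above, assuming the registration window currently spans indices 5..12.

#   cookie.offset = 3  -> behind the window, reset to 5 (start of the window)
#   cookie.offset = 9  -> inside the window, kept: entries with index < 9 are skipped
#   cookie.offset = 20 -> past the window, reset to 13 (one past high, i.e. an empty page)
#   namespace changed  -> reset to 5 regardless of the stored offset
# nextOffset then becomes lastReturnedIndex + 1, so the following request resumes there.
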
|
||||
|
||||
proc advertisePeer(
|
||||
rdv: RendezVous, peer: PeerId, msg: seq[byte]
|
||||
@@ -683,7 +691,7 @@ proc unsubscribeLocally*(rdv: RendezVous, ns: string) =
|
||||
try:
|
||||
for index in rdv.namespaces[nsSalted]:
|
||||
if rdv.registered[index].peerId == rdv.switch.peerInfo.peerId:
|
||||
rdv.registered[index].expiration = rdv.defaultDT
|
||||
rdv.registered[index].expiration = rdv.expiredDT
|
||||
except KeyError:
|
||||
return
|
||||
|
||||
@@ -738,10 +746,10 @@ proc new*(
|
||||
minDuration = MinimumDuration,
|
||||
maxDuration = MaximumDuration,
|
||||
): T {.raises: [RendezVousError].} =
|
||||
if minDuration < 1.minutes:
|
||||
if minDuration < MinimumAcceptedDuration:
|
||||
raise newException(RendezVousError, "TTL too short: 1 minute minimum")
|
||||
|
||||
if maxDuration > 72.hours:
|
||||
if maxDuration > MaximumDuration:
|
||||
raise newException(RendezVousError, "TTL too long: 72 hours maximum")
|
||||
|
||||
if minDuration >= maxDuration:
|
||||
@@ -754,8 +762,8 @@ proc new*(
|
||||
let rdv = T(
|
||||
rng: rng,
|
||||
salt: string.fromBytes(generateBytes(rng[], 8)),
|
||||
registered: initOffsettedSeq[RegisteredData](1),
|
||||
defaultDT: Moment.now() - 1.days,
|
||||
registered: initOffsettedSeq[RegisteredData](),
|
||||
expiredDT: Moment.now() - 1.days,
|
||||
#registerEvent: newAsyncEvent(),
|
||||
sema: newAsyncSemaphore(SemaphoreDefaultSize),
|
||||
minDuration: minDuration,
|
||||
@@ -806,7 +814,7 @@ proc new*(
|
||||
rdv.setup(switch)
|
||||
return rdv
|
||||
|
||||
proc deletesRegister(
|
||||
proc deletesRegister*(
|
||||
rdv: RendezVous, interval = 1.minutes
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
heartbeat "Register timeout", interval:
|
||||
|
||||
@@ -583,7 +583,8 @@ method handshake*(
|
||||
)
|
||||
conn.peerId = pid
|
||||
|
||||
var tmp = NoiseConnection.new(conn, conn.peerId, conn.observedAddr)
|
||||
var tmp =
|
||||
NoiseConnection.new(conn, conn.peerId, conn.observedAddr, conn.localAddr)
|
||||
if initiator:
|
||||
tmp.readCs = handshakeRes.cs2
|
||||
tmp.writeCs = handshakeRes.cs1
|
||||
|
||||
@@ -51,12 +51,14 @@ proc new*(
|
||||
conn: Connection,
|
||||
peerId: PeerId,
|
||||
observedAddr: Opt[MultiAddress],
|
||||
localAddr: Opt[MultiAddress],
|
||||
timeout: Duration = DefaultConnectionTimeout,
|
||||
): T =
|
||||
result = T(
|
||||
stream: conn,
|
||||
peerId: peerId,
|
||||
observedAddr: observedAddr,
|
||||
localAddr: localAddr,
|
||||
closeEvent: conn.closeEvent,
|
||||
timeout: timeout,
|
||||
dir: conn.dir,
|
||||
|
||||
@@ -62,8 +62,15 @@ proc init*(
|
||||
dir: Direction,
|
||||
timeout = DefaultChronosStreamTimeout,
|
||||
observedAddr: Opt[MultiAddress],
|
||||
localAddr: Opt[MultiAddress],
|
||||
): ChronosStream =
|
||||
result = C(client: client, timeout: timeout, dir: dir, observedAddr: observedAddr)
|
||||
result = C(
|
||||
client: client,
|
||||
timeout: timeout,
|
||||
dir: dir,
|
||||
observedAddr: observedAddr,
|
||||
localAddr: localAddr,
|
||||
)
|
||||
result.initStream()
|
||||
|
||||
template withExceptions(body: untyped) =
|
||||
@@ -151,6 +158,19 @@ method closed*(s: ChronosStream): bool =
|
||||
method atEof*(s: ChronosStream): bool =
|
||||
s.client.atEof()
|
||||
|
||||
method closeWrite*(s: ChronosStream) {.async: (raises: []).} =
|
||||
## Close the write side of the TCP connection using half-close
|
||||
if not s.client.closed():
|
||||
try:
|
||||
await s.client.shutdownWait()
|
||||
trace "Write side closed", address = $s.client.remoteAddress(), s
|
||||
except TransportError:
|
||||
# Ignore transport errors during shutdown
|
||||
discard
|
||||
except CatchableError:
|
||||
# Ignore other errors during shutdown
|
||||
discard
|
||||
|
||||
method closeImpl*(s: ChronosStream) {.async: (raises: []).} =
|
||||
trace "Shutting down chronos stream", address = $s.client.remoteAddress(), s
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ type
|
||||
timeoutHandler*: TimeoutHandler # timeout handler
|
||||
peerId*: PeerId
|
||||
observedAddr*: Opt[MultiAddress]
|
||||
localAddr*: Opt[MultiAddress]
|
||||
protocol*: string # protocol used by the connection, used as metrics tag
|
||||
transportDir*: Direction # underlying transport (usually socket) direction
|
||||
when defined(libp2p_agents_metrics):
|
||||
@@ -40,6 +41,12 @@ type
|
||||
|
||||
proc timeoutMonitor(s: Connection) {.async: (raises: []).}
|
||||
|
||||
method closeWrite*(s: Connection): Future[void] {.base, async: (raises: []).} =
|
||||
## Close the write side of the connection
|
||||
## Subclasses should implement this for their specific transport
|
||||
## Default implementation just closes the entire connection
|
||||
await s.close()
|
||||
|
||||
func shortLog*(conn: Connection): string =
|
||||
try:
|
||||
if conn == nil:
|
||||
@@ -133,13 +140,17 @@ when defined(libp2p_agents_metrics):
|
||||
var conn = s
|
||||
while conn != nil:
|
||||
conn.shortAgent = shortAgent
|
||||
conn = conn.getWrapped()
|
||||
let wrapped = conn.getWrapped()
|
||||
if wrapped == conn:
|
||||
break
|
||||
conn = wrapped
|
||||
|
||||
proc new*(
|
||||
C: type Connection,
|
||||
peerId: PeerId,
|
||||
dir: Direction,
|
||||
observedAddr: Opt[MultiAddress],
|
||||
observedAddr: Opt[MultiAddress] = Opt.none(MultiAddress),
|
||||
localAddr: Opt[MultiAddress] = Opt.none(MultiAddress),
|
||||
timeout: Duration = DefaultConnectionTimeout,
|
||||
timeoutHandler: TimeoutHandler = nil,
|
||||
): Connection =
|
||||
@@ -149,6 +160,7 @@ proc new*(
|
||||
timeout: timeout,
|
||||
timeoutHandler: timeoutHandler,
|
||||
observedAddr: observedAddr,
|
||||
localAddr: localAddr,
|
||||
)
|
||||
|
||||
result.initStream()
|
||||
|
||||
@@ -36,14 +36,19 @@ type QuicStream* = ref object of P2PConnection
|
||||
cached: seq[byte]
|
||||
|
||||
proc new(
|
||||
_: type QuicStream, stream: Stream, oaddr: Opt[MultiAddress], peerId: PeerId
|
||||
_: type QuicStream,
|
||||
stream: Stream,
|
||||
oaddr: Opt[MultiAddress],
|
||||
laddr: Opt[MultiAddress],
|
||||
peerId: PeerId,
|
||||
): QuicStream =
|
||||
let quicstream = QuicStream(stream: stream, observedAddr: oaddr, peerId: peerId)
|
||||
let quicstream =
|
||||
QuicStream(stream: stream, observedAddr: oaddr, localAddr: laddr, peerId: peerId)
|
||||
procCall P2PConnection(quicstream).initStream()
|
||||
quicstream
|
||||
|
||||
method getWrapped*(self: QuicStream): P2PConnection =
|
||||
nil
|
||||
self
|
||||
|
||||
template mapExceptions(body: untyped) =
|
||||
try:
|
||||
@@ -56,18 +61,23 @@ template mapExceptions(body: untyped) =
|
||||
method readOnce*(
|
||||
stream: QuicStream, pbytes: pointer, nbytes: int
|
||||
): Future[int] {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
try:
|
||||
if stream.cached.len == 0:
|
||||
if stream.cached.len == 0:
|
||||
try:
|
||||
stream.cached = await stream.stream.read()
|
||||
if stream.cached.len == 0:
|
||||
raise newLPStreamEOFError()
|
||||
except CancelledError as exc:
|
||||
raise exc
|
||||
except LPStreamEOFError as exc:
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
raise (ref LPStreamError)(msg: "error in readOnce: " & exc.msg, parent: exc)
|
||||
|
||||
result = min(nbytes, stream.cached.len)
|
||||
copyMem(pbytes, addr stream.cached[0], result)
|
||||
stream.cached = stream.cached[result ..^ 1]
|
||||
libp2p_network_bytes.inc(result.int64, labelValues = ["in"])
|
||||
except CatchableError as exc:
|
||||
raise newLPStreamEOFError()
|
||||
let toRead = min(nbytes, stream.cached.len)
|
||||
copyMem(pbytes, addr stream.cached[0], toRead)
|
||||
stream.cached = stream.cached[toRead ..^ 1]
|
||||
libp2p_network_bytes.inc(toRead.int64, labelValues = ["in"])
|
||||
return toRead
|
||||
|
||||
{.push warning[LockLevel]: off.}
|
||||
method write*(
|
||||
@@ -78,6 +88,13 @@ method write*(
|
||||
|
||||
{.pop.}
|
||||
|
||||
method closeWrite*(stream: QuicStream) {.async: (raises: []).} =
|
||||
## Close the write side of the QUIC stream
|
||||
try:
|
||||
await stream.stream.closeWrite()
|
||||
except CatchableError as exc:
|
||||
discard
|
||||
|
||||
method closeImpl*(stream: QuicStream) {.async: (raises: []).} =
|
||||
try:
|
||||
await stream.stream.close()
|
||||
@@ -108,7 +125,11 @@ proc getStream*(
|
||||
stream = await session.connection.openStream()
|
||||
await stream.write(@[]) # QUIC streams do not exist until data is sent
|
||||
|
||||
let qs = QuicStream.new(stream, session.observedAddr, session.peerId)
|
||||
let qs =
|
||||
QuicStream.new(stream, session.observedAddr, session.localAddr, session.peerId)
|
||||
when defined(libp2p_agents_metrics):
|
||||
qs.shortAgent = session.shortAgent
|
||||
|
||||
session.streams.add(qs)
|
||||
return qs
|
||||
except CatchableError as exc:
|
||||
@@ -116,13 +137,20 @@ proc getStream*(
|
||||
raise (ref QuicTransportError)(msg: "error in getStream: " & exc.msg, parent: exc)
|
||||
|
||||
method getWrapped*(self: QuicSession): P2PConnection =
|
||||
nil
|
||||
self
|
||||
|
||||
# Muxer
|
||||
type QuicMuxer = ref object of Muxer
|
||||
quicSession: QuicSession
|
||||
handleFut: Future[void]
|
||||
|
||||
when defined(libp2p_agents_metrics):
|
||||
method setShortAgent*(m: QuicMuxer, shortAgent: string) =
|
||||
m.quicSession.shortAgent = shortAgent
|
||||
for s in m.quicSession.streams:
|
||||
s.shortAgent = shortAgent
|
||||
m.connection.shortAgent = shortAgent
|
||||
|
||||
method newStream*(
|
||||
m: QuicMuxer, name: string = "", lazy: bool = false
|
||||
): Future[P2PConnection] {.
|
||||
@@ -141,7 +169,7 @@ proc handleStream(m: QuicMuxer, chann: QuicStream) {.async: (raises: []).} =
|
||||
trace "finished handling stream"
|
||||
doAssert(chann.closed, "connection not closed by handler!")
|
||||
except CatchableError as exc:
|
||||
trace "Exception in mplex stream handler", msg = exc.msg
|
||||
trace "Exception in quic stream handler", msg = exc.msg
|
||||
await chann.close()
|
||||
|
||||
method handle*(m: QuicMuxer): Future[void] {.async: (raises: []).} =
|
||||
@@ -150,7 +178,7 @@ method handle*(m: QuicMuxer): Future[void] {.async: (raises: []).} =
|
||||
let incomingStream = await m.quicSession.getStream(Direction.In)
|
||||
asyncSpawn m.handleStream(incomingStream)
|
||||
except CatchableError as exc:
|
||||
trace "Exception in mplex handler", msg = exc.msg
|
||||
trace "Exception in quic handler", msg = exc.msg
|
||||
|
||||
method close*(m: QuicMuxer) {.async: (raises: []).} =
|
||||
try:
|
||||
@@ -262,7 +290,8 @@ method start*(
|
||||
|
||||
method stop*(transport: QuicTransport) {.async: (raises: []).} =
|
||||
if transport.running:
|
||||
for c in transport.connections:
|
||||
let conns = transport.connections[0 .. ^1]
|
||||
for c in conns:
|
||||
await c.close()
|
||||
await procCall Transport(transport).stop()
|
||||
try:
|
||||
@@ -277,11 +306,17 @@ proc wrapConnection(
|
||||
transport: QuicTransport, connection: QuicConnection
|
||||
): QuicSession {.raises: [TransportOsError, MaError].} =
|
||||
let
|
||||
remoteAddr = connection.remoteAddress()
|
||||
observedAddr =
|
||||
MultiAddress.init(remoteAddr, IPPROTO_UDP).get() &
|
||||
MultiAddress.init(connection.remoteAddress(), IPPROTO_UDP).get() &
|
||||
MultiAddress.init("/quic-v1").get()
|
||||
session = QuicSession(connection: connection, observedAddr: Opt.some(observedAddr))
|
||||
localAddr =
|
||||
MultiAddress.init(connection.localAddress(), IPPROTO_UDP).get() &
|
||||
MultiAddress.init("/quic-v1").get()
|
||||
session = QuicSession(
|
||||
connection: connection,
|
||||
observedAddr: Opt.some(observedAddr),
|
||||
localAddr: Opt.some(localAddr),
|
||||
)
|
||||
|
||||
session.initStream()
|
||||
|
||||
@@ -301,12 +336,12 @@ method accept*(
|
||||
): Future[connection.Connection] {.
|
||||
async: (raises: [transport.TransportError, CancelledError])
|
||||
.} =
|
||||
doAssert not self.listener.isNil, "call start() before calling accept()"
|
||||
|
||||
if not self.running:
|
||||
# stop accept only when transport is stopped (not when error occurs)
|
||||
raise newException(QuicTransportAcceptStopped, "Quic transport stopped")
|
||||
|
||||
doAssert not self.listener.isNil, "call start() before calling accept()"
|
||||
|
||||
try:
|
||||
let connection = await self.listener.accept()
|
||||
return self.wrapConnection(connection)
|
||||
|
||||
@@ -47,6 +47,7 @@ proc connHandler*(
|
||||
self: TcpTransport,
|
||||
client: StreamTransport,
|
||||
observedAddr: Opt[MultiAddress],
|
||||
localAddr: Opt[MultiAddress],
|
||||
dir: Direction,
|
||||
): Connection =
|
||||
trace "Handling tcp connection",
|
||||
@@ -59,6 +60,7 @@ proc connHandler*(
|
||||
client = client,
|
||||
dir = dir,
|
||||
observedAddr = observedAddr,
|
||||
localAddr = localAddr,
|
||||
timeout = self.connectionsTimeout,
|
||||
)
|
||||
)
|
||||
@@ -267,18 +269,22 @@ method accept*(
|
||||
safeCloseWait(transp)
|
||||
raise newTransportClosedError()
|
||||
|
||||
let remote =
|
||||
let (localAddr, observedAddr) =
|
||||
try:
|
||||
transp.remoteAddress
|
||||
(
|
||||
MultiAddress.init(transp.localAddress).expect(
|
||||
"Can initialize from local address"
|
||||
),
|
||||
MultiAddress.init(transp.remoteAddress).expect(
|
||||
"Can initialize from remote address"
|
||||
),
|
||||
)
|
||||
except TransportOsError as exc:
|
||||
# The connection had errors / was closed before `await` returned control
|
||||
safeCloseWait(transp)
|
||||
debug "Cannot read remote address", description = exc.msg
|
||||
debug "Cannot read address", description = exc.msg
|
||||
return nil
|
||||
|
||||
let observedAddr =
|
||||
MultiAddress.init(remote).expect("Can initialize from remote address")
|
||||
self.connHandler(transp, Opt.some(observedAddr), Direction.In)
|
||||
self.connHandler(transp, Opt.some(observedAddr), Opt.some(localAddr), Direction.In)
|
||||
|
||||
method dial*(
|
||||
self: TcpTransport,
|
||||
@@ -320,14 +326,17 @@ method dial*(
|
||||
safeCloseWait(transp)
|
||||
raise newTransportClosedError()
|
||||
|
||||
let observedAddr =
|
||||
let (observedAddr, localAddr) =
|
||||
try:
|
||||
MultiAddress.init(transp.remoteAddress).expect("remote address is valid")
|
||||
(
|
||||
MultiAddress.init(transp.remoteAddress).expect("remote address is valid"),
|
||||
MultiAddress.init(transp.localAddress).expect("local address is valid"),
|
||||
)
|
||||
except TransportOsError as exc:
|
||||
safeCloseWait(transp)
|
||||
raise (ref TcpTransportError)(msg: "MultiAddress.init error in dial: " & exc.msg)
|
||||
|
||||
self.connHandler(transp, Opt.some(observedAddr), Direction.Out)
|
||||
self.connHandler(transp, Opt.some(observedAddr), Opt.some(localAddr), Direction.Out)
|
||||
|
||||
method handles*(t: TcpTransport, address: MultiAddress): bool {.raises: [].} =
|
||||
if procCall Transport(t).handles(address):
|
||||
|
||||
@@ -237,7 +237,9 @@ method dial*(
|
||||
try:
|
||||
transp = await connectToTorServer(self.transportAddress)
|
||||
await dialPeer(transp, address)
|
||||
return self.tcpTransport.connHandler(transp, Opt.none(MultiAddress), Direction.Out)
|
||||
return self.tcpTransport.connHandler(
|
||||
transp, Opt.none(MultiAddress), Opt.none(MultiAddress), Direction.Out
|
||||
)
|
||||
except CancelledError as e:
|
||||
safeCloseWait(transp)
|
||||
raise e
|
||||
|
||||
@@ -18,9 +18,9 @@ import
|
||||
../multicodec,
|
||||
../muxers/muxer,
|
||||
../upgrademngrs/upgrade,
|
||||
../protocols/connectivity/autonat/core
|
||||
../protocols/connectivity/autonat/types
|
||||
|
||||
export core.NetworkReachability
|
||||
export types.NetworkReachability
|
||||
|
||||
logScope:
|
||||
topics = "libp2p transport"
|
||||
|
||||
@@ -55,10 +55,16 @@ proc new*(
|
||||
session: WSSession,
|
||||
dir: Direction,
|
||||
observedAddr: Opt[MultiAddress],
|
||||
localAddr: Opt[MultiAddress],
|
||||
timeout = 10.minutes,
|
||||
): T =
|
||||
let stream =
|
||||
T(session: session, timeout: timeout, dir: dir, observedAddr: observedAddr)
|
||||
let stream = T(
|
||||
session: session,
|
||||
timeout: timeout,
|
||||
dir: dir,
|
||||
observedAddr: observedAddr,
|
||||
localAddr: localAddr,
|
||||
)
|
||||
|
||||
stream.initStream()
|
||||
return stream
|
||||
@@ -212,11 +218,9 @@ method stop*(self: WsTransport) {.async: (raises: []).} =
|
||||
trace "Stopping WS transport"
|
||||
await procCall Transport(self).stop() # call base
|
||||
|
||||
checkFutures(
|
||||
await allFinished(
|
||||
self.connections[Direction.In].mapIt(it.close()) &
|
||||
self.connections[Direction.Out].mapIt(it.close())
|
||||
)
|
||||
discard await allFinished(
|
||||
self.connections[Direction.In].mapIt(it.close()) &
|
||||
self.connections[Direction.Out].mapIt(it.close())
|
||||
)
|
||||
|
||||
var toWait: seq[Future[void]]
|
||||
@@ -241,9 +245,8 @@ proc connHandler(
|
||||
self: WsTransport, stream: WSSession, secure: bool, dir: Direction
|
||||
): Future[Connection] {.async: (raises: [CatchableError]).} =
|
||||
## Returning CatchableError is fine because we later handle different exceptions.
|
||||
##
|
||||
|
||||
let observedAddr =
|
||||
let (observedAddr, localAddr) =
|
||||
try:
|
||||
let
|
||||
codec =
|
||||
@@ -252,15 +255,19 @@ proc connHandler(
|
||||
else:
|
||||
MultiAddress.init("/ws")
|
||||
remoteAddr = stream.stream.reader.tsource.remoteAddress
|
||||
localAddr = stream.stream.reader.tsource.localAddress
|
||||
|
||||
MultiAddress.init(remoteAddr).tryGet() & codec.tryGet()
|
||||
(
|
||||
MultiAddress.init(remoteAddr).tryGet() & codec.tryGet(),
|
||||
MultiAddress.init(localAddr).tryGet() & codec.tryGet(),
|
||||
)
|
||||
except CatchableError as exc:
|
||||
trace "Failed to create observedAddr", description = exc.msg
|
||||
trace "Failed to create observedAddr or listenAddr", description = exc.msg
|
||||
if not (isNil(stream) and stream.stream.reader.closed):
|
||||
safeClose(stream)
|
||||
raise exc
|
||||
|
||||
let conn = WsStream.new(stream, dir, Opt.some(observedAddr))
|
||||
let conn = WsStream.new(stream, dir, Opt.some(observedAddr), Opt.some(localAddr))
|
||||
|
||||
self.connections[dir].add(conn)
|
||||
proc onClose() {.async: (raises: []).} =
|
||||
|
||||
libp2p/utils/ipaddr.nim (new file, 74 lines)
@@ -0,0 +1,74 @@
|
||||
import net, strutils
|
||||
|
||||
import ../switch, ../multiaddress, ../multicodec
|
||||
|
||||
proc isIPv4*(ip: IpAddress): bool =
|
||||
ip.family == IpAddressFamily.IPv4
|
||||
|
||||
proc isIPv6*(ip: IpAddress): bool =
|
||||
ip.family == IpAddressFamily.IPv6
|
||||
|
||||
proc isPrivate*(ip: string): bool {.raises: [ValueError].} =
|
||||
ip.startsWith("10.") or
|
||||
(ip.startsWith("172.") and parseInt(ip.split(".")[1]) in 16 .. 31) or
|
||||
ip.startsWith("192.168.") or ip.startsWith("127.") or ip.startsWith("169.254.")
|
||||
|
||||
proc isPrivate*(ip: IpAddress): bool {.raises: [ValueError].} =
|
||||
isPrivate($ip)
|
||||
|
||||
proc isPublic*(ip: string): bool {.raises: [ValueError].} =
|
||||
not isPrivate(ip)
|
||||
|
||||
proc isPublic*(ip: IpAddress): bool {.raises: [ValueError].} =
|
||||
isPublic($ip)
|
||||
|
||||
proc getPublicIPAddress*(): IpAddress {.raises: [OSError, ValueError].} =
|
||||
let ip =
|
||||
try:
|
||||
getPrimaryIPAddr()
|
||||
except OSError as exc:
|
||||
raise exc
|
||||
except ValueError as exc:
|
||||
raise exc
|
||||
except Exception as exc:
|
||||
raise newException(OSError, "Could not get primary IP address")
|
||||
if not ip.isIPv4():
|
||||
raise newException(ValueError, "Host does not have an IPv4 address")
|
||||
if not ip.isPublic():
|
||||
raise newException(ValueError, "Host does not have a public IPv4 address")
|
||||
ip
|
||||
|
||||
proc ipAddrMatches*(
|
||||
lookup: MultiAddress, addrs: seq[MultiAddress], ip4: bool = true
|
||||
): bool =
|
||||
## Checks whether ``lookup``'s IP matches any address in ``addrs``
|
||||
|
||||
let ipType =
|
||||
if ip4:
|
||||
multiCodec("ip4")
|
||||
else:
|
||||
multiCodec("ip6")
|
||||
|
||||
let lookup = lookup.getPart(ipType).valueOr:
|
||||
return false
|
||||
|
||||
for ma in addrs:
|
||||
ma[0].withValue(ipAddr):
|
||||
if ipAddr == lookup:
|
||||
return true
|
||||
false
|
||||
|
||||
proc ipSupport*(addrs: seq[MultiAddress]): (bool, bool) =
|
||||
## Returns ipv4 and ipv6 support status of a list of MultiAddresses
|
||||
|
||||
var ipv4 = false
|
||||
var ipv6 = false
|
||||
|
||||
for ma in addrs:
|
||||
ma[0].withValue(addrIp):
|
||||
if IP4.match(addrIp):
|
||||
ipv4 = true
|
||||
elif IP6.match(addrIp):
|
||||
ipv6 = true
|
||||
|
||||
(ipv4, ipv6)
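
A hedged usage sketch for the helpers in this new module; the multiaddresses are documentation-style examples.

let addrs = @[
  MultiAddress.init("/ip4/203.0.113.7/tcp/4001").tryGet(),
  MultiAddress.init("/ip6/2001:db8::1/tcp/4001").tryGet(),
]
let (hasV4, hasV6) = ipSupport(addrs)             # (true, true) for the list above
let lookup = MultiAddress.init("/ip4/203.0.113.7/udp/9000/quic-v1").tryGet()
echo ipAddrMatches(lookup, addrs)                 # expected true: same /ip4/ component
doAssert isPrivate("192.168.1.20") and isPublic("203.0.113.7")
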
|
||||
@@ -11,7 +11,8 @@ fi
|
||||
# Clean up output
|
||||
output_dir="$(pwd)/performance/output"
|
||||
mkdir -p "$output_dir"
|
||||
rm -f "$output_dir"/*.json
|
||||
rm -rf "$output_dir"
|
||||
mkdir -p "$output_dir/sync"
|
||||
|
||||
# Run Test Nodes
|
||||
container_names=()
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
import metrics
|
||||
import metrics/chronos_httpserver
|
||||
import os
|
||||
import osproc
|
||||
import strformat
|
||||
import strutils
|
||||
import ../libp2p
|
||||
@@ -42,6 +43,14 @@ proc baseTest*(scenarioName = "Base test") {.async.} =
|
||||
if nodeId == 0:
|
||||
clearSyncFiles()
|
||||
|
||||
# --- Collect docker stats for one publishing and one non-publishing node ---
|
||||
var dockerStatsProc: Process = nil
|
||||
if nodeId == 0 or nodeId == publisherCount + 1:
|
||||
let dockerStatsLogPath = getDockerStatsLogPath(scenario, nodeId)
|
||||
dockerStatsProc = startDockerStatsProcess(nodeId, dockerStatsLogPath)
|
||||
defer:
|
||||
dockerStatsProc.stopDockerStatsProcess()
|
||||
|
||||
let (switch, gossipSub, pingProtocol) = setupNode(nodeId, rng)
|
||||
gossipSub.setGossipSubParams()
|
||||
|
||||
|
||||
performance/scripts/add_plots_to_summary.nim (new file, 87 lines)
@@ -0,0 +1,87 @@
|
||||
import os
|
||||
import algorithm
|
||||
import sequtils
|
||||
import strformat
|
||||
import strutils
|
||||
import tables
|
||||
|
||||
proc getImgUrlBase(repo: string, publishBranchName: string, plotsPath: string): string =
|
||||
&"https://raw.githubusercontent.com/{repo}/refs/heads/{publishBranchName}/{plotsPath}"
|
||||
|
||||
proc extractTestName(base: string): string =
|
||||
let parts = base.split("_")
|
||||
if parts.len >= 2:
|
||||
parts[^2]
|
||||
else:
|
||||
base
|
||||
|
||||
proc makeImgTag(imgUrl: string, width: int): string =
|
||||
&"<img src=\"{imgUrl}\" width=\"{width}\" style=\"margin-right:10px;\" />"
|
||||
|
||||
proc prepareLatencyHistoryImage(
|
||||
imgUrlBase: string, latencyHistoryFilePath: string, width: int = 600
|
||||
): string =
|
||||
let latencyImgUrl = &"{imgUrlBase}/{latencyHistoryFilePath}"
|
||||
makeImgTag(latencyImgUrl, width)
|
||||
|
||||
proc prepareDockerStatsImages(
|
||||
plotDir: string, imgUrlBase: string, branchName: string, width: int = 450
|
||||
): Table[string, seq[string]] =
|
||||
## Groups docker stats plot images by test name and returns HTML <img> tags.
|
||||
var grouped: Table[string, seq[string]]
|
||||
|
||||
for path in walkFiles(&"{plotDir}/*.png"):
|
||||
let plotFile = path.splitPath.tail
|
||||
let testName = extractTestName(plotFile)
|
||||
let imgUrl = &"{imgUrlBase}/{branchName}/{plotFile}"
|
||||
let imgTag = makeImgTag(imgUrl, width)
|
||||
discard grouped.hasKeyOrPut(testName, @[])
|
||||
grouped[testName].add(imgTag)
|
||||
|
||||
grouped
|
||||
|
||||
proc buildSummary(
|
||||
plotDir: string,
|
||||
repo: string,
|
||||
branchName: string,
|
||||
publishBranchName: string,
|
||||
plotsPath: string,
|
||||
latencyHistoryFilePath: string,
|
||||
): string =
|
||||
let imgUrlBase = getImgUrlBase(repo, publishBranchName, plotsPath)
|
||||
|
||||
var buf: seq[string]
|
||||
|
||||
# Latency History section
|
||||
buf.add("## Latency History")
|
||||
buf.add(prepareLatencyHistoryImage(imgUrlBase, latencyHistoryFilePath) & "<br>")
|
||||
buf.add("")
|
||||
|
||||
# Performance Plots section
|
||||
let grouped = prepareDockerStatsImages(plotDir, imgUrlBase, branchName)
|
||||
buf.add(&"## Performance Plots for {branchName}")
|
||||
for test in grouped.keys.toSeq().sorted():
|
||||
let imgs = grouped[test]
|
||||
buf.add(&"### {test}")
|
||||
buf.add(imgs.join(" ") & "<br>")
|
||||
|
||||
buf.join("\n")
|
||||
|
||||
proc main() =
|
||||
let summaryPath = getEnv("GITHUB_STEP_SUMMARY", "/tmp/step_summary.md")
|
||||
let repo = getEnv("GITHUB_REPOSITORY", "vacp2p/nim-libp2p")
|
||||
let branchName = getEnv("BRANCH_NAME", "")
|
||||
let publishBranchName = getEnv("PUBLISH_BRANCH_NAME", "performance_plots")
|
||||
let plotsPath = getEnv("PLOTS_PATH", "plots")
|
||||
let latencyHistoryFilePath =
|
||||
getEnv("LATENCY_HISTORY_PLOT_FILENAME", "latency_history_all_scenarios.png")
|
||||
let checkoutSubfolder = getEnv("CHECKOUT_SUBFOLDER", "subplots")
|
||||
let plotDir = &"{checkoutSubfolder}/{plotsPath}/{branchName}"
|
||||
|
||||
let summary = buildSummary(
|
||||
plotDir, repo, branchName, publishBranchName, plotsPath, latencyHistoryFilePath
|
||||
)
|
||||
writeFile(summaryPath, summary)
|
||||
echo summary
|
||||
|
||||
main()
|
||||
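The grouping in `prepareDockerStatsImages` relies on the plot filenames ending in `<testName>_<nodeId>.png`, which is why `extractTestName` takes the second-to-last underscore-separated part. A small Python sketch of the same grouping; the filenames below are illustrative, chosen to match the sanitised scenario names produced elsewhere in this change:

```python
from collections import defaultdict

def extract_test_name(filename):
    # mirror of extractTestName: second-to-last "_"-separated part
    parts = filename.split("_")
    return parts[-2] if len(parts) >= 2 else filename

plots = ["docker_stats_Basetest_0.png", "docker_stats_Basetest_6.png"]
grouped = defaultdict(list)
for name in plots:
    grouped[extract_test_name(name)].append(name)

print(dict(grouped))  # {'Basetest': ['docker_stats_Basetest_0.png', 'docker_stats_Basetest_6.png']}
```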
126
performance/scripts/plot_docker_stats.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import os
|
||||
import glob
|
||||
import csv
|
||||
import statistics
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def parse_csv(filepath):
|
||||
timestamps = []
|
||||
cpu_percent = []
|
||||
mem_usage_mb = []
|
||||
download_MBps = []
|
||||
upload_MBps = []
|
||||
download_MB = []
|
||||
upload_MB = []
|
||||
with open(filepath, "r") as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
timestamps.append(float(row["timestamp"]))
|
||||
cpu_percent.append(float(row["cpu_percent"]))
|
||||
mem_usage_mb.append(float(row["mem_usage_mb"]))
|
||||
download_MBps.append(float(row["download_MBps"]))
|
||||
upload_MBps.append(float(row["upload_MBps"]))
|
||||
download_MB.append(float(row["download_MB"]))
|
||||
upload_MB.append(float(row["upload_MB"]))
|
||||
return {
|
||||
"timestamps": timestamps,
|
||||
"cpu_percent": cpu_percent,
|
||||
"mem_usage_mb": mem_usage_mb,
|
||||
"download_MBps": download_MBps,
|
||||
"upload_MBps": upload_MBps,
|
||||
"download_MB": download_MB,
|
||||
"upload_MB": upload_MB,
|
||||
}
|
||||
|
||||
|
||||
def plot_metrics(data, title, output_path):
|
||||
timestamps = data["timestamps"]
|
||||
time_points = [t - timestamps[0] for t in timestamps]
|
||||
cpu = data["cpu_percent"]
|
||||
mem = data["mem_usage_mb"]
|
||||
download_MBps = data["download_MBps"]
|
||||
upload_MBps = data["upload_MBps"]
|
||||
download_MB = data["download_MB"]
|
||||
upload_MB = data["upload_MB"]
|
||||
|
||||
cpu_median = statistics.median(cpu)
|
||||
cpu_max = max(cpu)
|
||||
mem_median = statistics.median(mem)
|
||||
mem_max = max(mem)
|
||||
download_MBps_median = statistics.median(download_MBps)
|
||||
download_MBps_max = max(download_MBps)
|
||||
upload_MBps_median = statistics.median(upload_MBps)
|
||||
upload_MBps_max = max(upload_MBps)
|
||||
download_MB_total = download_MB[-1]
|
||||
upload_MB_total = upload_MB[-1]
|
||||
|
||||
fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, figsize=(12, 16), sharex=True)
|
||||
fig.suptitle(title, fontsize=16)
|
||||
|
||||
# CPU Usage
|
||||
ax1.plot(time_points, cpu, "b-", label=f"CPU Usage (%)\nmedian = {cpu_median:.2f}\nmax = {cpu_max:.2f}")
|
||||
ax1.set_ylabel("CPU Usage (%)")
|
||||
ax1.set_title("CPU Usage Over Time")
|
||||
ax1.grid(True)
|
||||
ax1.set_xlim(left=0)
|
||||
ax1.set_ylim(bottom=0)
|
||||
ax1.legend(loc="best")
|
||||
|
||||
# Memory Usage
|
||||
ax2.plot(time_points, mem, "m-", label=f"Memory Usage (MB)\nmedian = {mem_median:.2f} MB\nmax = {mem_max:.2f} MB")
|
||||
ax2.set_ylabel("Memory Usage (MB)")
|
||||
ax2.set_title("Memory Usage Over Time")
|
||||
ax2.grid(True)
|
||||
ax2.set_xlim(left=0)
|
||||
ax2.set_ylim(bottom=0)
|
||||
ax2.legend(loc="best")
|
||||
|
||||
# Network Throughput
|
||||
ax3.plot(
|
||||
time_points,
|
||||
download_MBps,
|
||||
"c-",
|
||||
label=f"Download (MB/s)\nmedian = {download_MBps_median:.2f} MB/s\nmax = {download_MBps_max:.2f} MB/s",
|
||||
linewidth=2,
|
||||
)
|
||||
ax3.plot(
|
||||
time_points, upload_MBps, "r-", label=f"Upload (MB/s)\nmedian = {upload_MBps_median:.2f} MB/s\nmax = {upload_MBps_max:.2f} MB/s", linewidth=2
|
||||
)
|
||||
ax3.set_ylabel("Network Throughput (MB/s)")
|
||||
ax3.set_title("Network Activity Over Time")
|
||||
ax3.grid(True)
|
||||
ax3.set_xlim(left=0)
|
||||
ax3.set_ylim(bottom=0)
|
||||
ax3.legend(loc="best", labelspacing=2)
|
||||
|
||||
# Accumulated Network Data
|
||||
ax4.plot(time_points, download_MB, "c-", label=f"Download (MB), total: {download_MB_total:.2f} MB", linewidth=2)
|
||||
ax4.plot(time_points, upload_MB, "r-", label=f"Upload (MB), total: {upload_MB_total:.2f} MB", linewidth=2)
|
||||
ax4.set_xlabel("Time (seconds)")
|
||||
ax4.set_ylabel("Total Data Transferred (MB)")
|
||||
ax4.set_title("Accumulated Network Data Over Time")
|
||||
ax4.grid(True)
|
||||
ax4.set_xlim(left=0)
|
||||
ax4.set_ylim(bottom=0)
|
||||
ax4.legend(loc="best")
|
||||
|
||||
plt.tight_layout(rect=(0, 0, 1, 1))
|
||||
os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
|
||||
plt.savefig(output_path, dpi=100, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
print(f"Saved plot to {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
shared_volume_path = os.environ.get("SHARED_VOLUME_PATH", "performance/output")
|
||||
docker_stats_prefix = os.environ.get("DOCKER_STATS_PREFIX", "docker_stats_")
|
||||
glob_pattern = os.path.join(shared_volume_path, f"{docker_stats_prefix}*.csv")
|
||||
csv_files = glob.glob(glob_pattern)
|
||||
for csv_file in csv_files:
|
||||
file_name = os.path.splitext(os.path.basename(csv_file))[0]
|
||||
data = parse_csv(csv_file)
|
||||
plot_metrics(data, title=file_name, output_path=os.path.join(shared_volume_path, f"{file_name}.png"))
|
||||
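`parse_csv` consumes the per-sample CSV emitted by `process_docker_stats.nim` (same column names as the header written by `writeCsvSeries`). A minimal end-to-end check with synthetic numbers, just to show the expected shape of the data:

```python
import csv, io

csv_text = (
    "timestamp,cpu_percent,mem_usage_mb,download_MBps,upload_MBps,download_MB,upload_MB\n"
    "0.00,12.50,3.00,0.1000,0.0500,0.1000,0.0500\n"
    "1.00,15.00,4.20,0.2000,0.1000,0.3000,0.1500\n"
)
rows = list(csv.DictReader(io.StringIO(csv_text)))
cpu = [float(r["cpu_percent"]) for r in rows]
assert max(cpu) == 15.0  # same max/median aggregation plot_metrics uses for the CPU panel label
```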
99
performance/scripts/plot_latency_history.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import os
|
||||
import glob
|
||||
import csv
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def extract_pr_number(filename):
|
||||
"""Extract PR number from filename of format pr{number}_anything.csv"""
|
||||
fname = os.path.basename(filename)
|
||||
parts = fname.split("_", 1)
|
||||
pr_str = parts[0][2:]
|
||||
if not pr_str.isdigit():
|
||||
return None
|
||||
return int(pr_str)
|
||||
|
||||
|
||||
def parse_latency_csv(csv_files):
|
||||
pr_numbers = []
|
||||
scenario_data = {} # scenario -> {pr_num: {min, avg, max}}
|
||||
for csv_file in csv_files:
|
||||
pr_num = extract_pr_number(csv_file)
|
||||
if pr_num is None:
|
||||
continue
|
||||
pr_numbers.append(pr_num)
|
||||
with open(csv_file, newline="") as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
scenario = row["Scenario"]
|
||||
if scenario not in scenario_data:
|
||||
scenario_data[scenario] = {}
|
||||
scenario_data[scenario][pr_num] = {
|
||||
"min": float(row["MinLatencyMs"]),
|
||||
"avg": float(row["AvgLatencyMs"]),
|
||||
"max": float(row["MaxLatencyMs"]),
|
||||
}
|
||||
pr_numbers = sorted(set(pr_numbers))
|
||||
return pr_numbers, scenario_data
|
||||
|
||||
|
||||
def plot_latency_history(pr_numbers, scenario_data, output_path):
|
||||
if not pr_numbers or not scenario_data:
|
||||
print("No PR latency data found; skipping plot generation.")
|
||||
return
|
||||
|
||||
num_scenarios = len(scenario_data)
|
||||
fig, axes = plt.subplots(num_scenarios, 1, figsize=(14, 4 * num_scenarios), sharex=True)
|
||||
if num_scenarios == 1:
|
||||
axes = [axes]
|
||||
|
||||
color_map = plt.colormaps.get_cmap("tab10")
|
||||
|
||||
x_positions = list(range(len(pr_numbers)))
|
||||
|
||||
for i, (scenario, pr_stats) in enumerate(scenario_data.items()):
|
||||
ax = axes[i]
|
||||
min_vals = [pr_stats.get(pr, {"min": None})["min"] for pr in pr_numbers]
|
||||
avg_vals = [pr_stats.get(pr, {"avg": None})["avg"] for pr in pr_numbers]
|
||||
max_vals = [pr_stats.get(pr, {"max": None})["max"] for pr in pr_numbers]
|
||||
|
||||
color = color_map(i % color_map.N)
|
||||
|
||||
if any(v is not None for v in avg_vals):
|
||||
ax.plot(x_positions, avg_vals, marker="o", label="Avg Latency (ms)", color=color)
|
||||
ax.fill_between(x_positions, min_vals, max_vals, color=color, alpha=0.2, label="Min-Max Latency (ms)")
|
||||
for x, avg, minv, maxv in zip(x_positions, avg_vals, min_vals, max_vals):
|
||||
if avg is not None:
|
||||
ax.scatter(x, avg, color=color)
|
||||
ax.text(x, avg, f"{avg:.3f}", fontsize=14, ha="center", va="bottom")
|
||||
if minv is not None and maxv is not None:
|
||||
ax.vlines(x, minv, maxv, color=color, alpha=0.5)
|
||||
|
||||
ax.set_ylabel("Latency (ms)")
|
||||
ax.set_title(f"Scenario: {scenario}")
|
||||
ax.legend(loc="upper left", fontsize="small")
|
||||
ax.grid(True, linestyle="--", alpha=0.5)
|
||||
|
||||
# Set X axis ticks and labels to show all PR numbers as 'PR <number>'
|
||||
axes[-1].set_xlabel("PR Number")
|
||||
axes[-1].set_xticks(x_positions)
|
||||
axes[-1].set_xticklabels([f"PR {pr}" for pr in pr_numbers], rotation=45, ha="right", fontsize=14)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_path)
|
||||
print(f"Saved combined plot to {output_path}")
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
LATENCY_HISTORY_PATH = os.environ.get("LATENCY_HISTORY_PATH", "performance/output")
|
||||
LATENCY_HISTORY_PREFIX = os.environ.get("LATENCY_HISTORY_PREFIX", "pr")
|
||||
    LATENCY_HISTORY_PLOT_FILENAME = os.environ.get("LATENCY_HISTORY_PLOT_FILENAME", "latency_history_all_scenarios.png")
|
||||
glob_pattern = os.path.join(LATENCY_HISTORY_PATH, f"{LATENCY_HISTORY_PREFIX}[0-9]*_latency.csv")
|
||||
csv_files = sorted(glob.glob(glob_pattern))
|
||||
pr_numbers, scenario_data = parse_latency_csv(csv_files)
|
||||
output_path = os.path.join(LATENCY_HISTORY_PATH, LATENCY_HISTORY_PLOT_FILENAME)
|
||||
plot_latency_history(pr_numbers, scenario_data, output_path)
|
||||
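The history plot depends on two conventions defined elsewhere in this change: filenames of the form `pr<number>_latency.csv` (see `getCsvFilename`) and the column set written by `getCsvReport`. A hedged Python sketch of how one file's rows map into `scenario_data`; the PR number and latency values are made up:

```python
import csv, io

filename = "pr1234_latency.csv"
pr_num = int(filename.split("_", 1)[0][2:])  # -> 1234, mirrors extract_pr_number

csv_text = (
    "Scenario,Nodes,TotalSent,TotalReceived,MinLatencyMs,MaxLatencyMs,AvgLatencyMs\n"
    "Base test,10,100,990,1.000,9.000,3.500\n"
)
scenario_data = {}
for row in csv.DictReader(io.StringIO(csv_text)):
    scenario_data.setdefault(row["Scenario"], {})[pr_num] = {
        "min": float(row["MinLatencyMs"]),
        "avg": float(row["AvgLatencyMs"]),
        "max": float(row["MaxLatencyMs"]),
    }
```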
156
performance/scripts/process_docker_stats.nim
Normal file
@@ -0,0 +1,156 @@
|
||||
from times import parse, toTime, toUnix
|
||||
import strformat
|
||||
import strutils
|
||||
import json
|
||||
import os
|
||||
import options
|
||||
|
||||
type DockerStatsSample = object
|
||||
timestamp: float
|
||||
cpuPercent: float
|
||||
memUsageMB: float
|
||||
netRxMB: float
|
||||
netTxMB: float
|
||||
|
||||
proc parseTimestamp(statsJson: JsonNode): float =
|
||||
let isoStr = statsJson["read"].getStr("")
|
||||
let mainPart = isoStr[0 ..< ^1] # remove trailing 'Z'
|
||||
let parts = mainPart.split(".")
|
||||
let dt = parse(parts[0], "yyyy-MM-dd'T'HH:mm:ss")
|
||||
|
||||
var nanos = 0
|
||||
if parts.len == 2:
|
||||
let nsStr = parts[1]
|
||||
let nsStrPadded = nsStr & repeat('0', 9 - nsStr.len)
|
||||
nanos = parseInt(nsStrPadded)
|
||||
let epochNano = dt.toTime.toUnix * 1_000_000_000 + nanos
|
||||
|
||||
# Return timestamp in seconds since Unix epoch
|
||||
return float(epochNano) / 1_000_000_000.0
|
||||
|
||||
proc extractCpuRaw(statsJson: JsonNode): (int, int, int) =
|
||||
let cpuStats = statsJson["cpu_stats"]
|
||||
let precpuStats = statsJson["precpu_stats"]
|
||||
let totalUsage = cpuStats["cpu_usage"]["total_usage"].getInt(0)
|
||||
let prevTotalUsage = precpuStats["cpu_usage"]["total_usage"].getInt(0)
|
||||
let systemUsage = cpuStats["system_cpu_usage"].getInt(0)
|
||||
let prevSystemUsage = precpuStats["system_cpu_usage"].getInt(0)
|
||||
let numCpus = cpuStats["online_cpus"].getInt(0)
|
||||
return (totalUsage - prevTotalUsage, systemUsage - prevSystemUsage, numCpus)
|
||||
|
||||
proc calcCpuPercent(cpuDelta: int, systemDelta: int, numCpus: int): float =
|
||||
if systemDelta > 0 and cpuDelta > 0 and numCpus > 0:
|
||||
return (float(cpuDelta) / float(systemDelta)) * float(numCpus) * 100.0
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
proc extractMemUsageRaw(statsJson: JsonNode): int =
|
||||
let memStats = statsJson["memory_stats"]
|
||||
return memStats["usage"].getInt(0)
|
||||
|
||||
proc extractNetworkRaw(statsJson: JsonNode): (int, int) =
|
||||
var netRxBytes = 0
|
||||
var netTxBytes = 0
|
||||
if "networks" in statsJson:
|
||||
for k, v in statsJson["networks"]:
|
||||
netRxBytes += v["rx_bytes"].getInt(0)
|
||||
netTxBytes += v["tx_bytes"].getInt(0)
|
||||
return (netRxBytes, netTxBytes)
|
||||
|
||||
proc convertMB(bytes: int): float =
|
||||
return float(bytes) / 1024.0 / 1024.0
|
||||
|
||||
proc parseDockerStatsLine(line: string): Option[DockerStatsSample] =
|
||||
var samples = none(DockerStatsSample)
|
||||
if line.len == 0:
|
||||
return samples
|
||||
try:
|
||||
let statsJson = parseJson(line)
|
||||
let timestamp = parseTimestamp(statsJson)
|
||||
let (cpuDelta, systemDelta, numCpus) = extractCpuRaw(statsJson)
|
||||
let cpuPercent = calcCpuPercent(cpuDelta, systemDelta, numCpus)
|
||||
let memUsageMB = extractMemUsageRaw(statsJson).convertMB()
|
||||
let (netRxRaw, netTxRaw) = extractNetworkRaw(statsJson)
|
||||
let netRxMB = netRxRaw.convertMB()
|
||||
let netTxMB = netTxRaw.convertMB()
|
||||
return some(
|
||||
DockerStatsSample(
|
||||
timestamp: timestamp,
|
||||
cpuPercent: cpuPercent,
|
||||
memUsageMB: memUsageMB,
|
||||
netRxMB: netRxMB,
|
||||
netTxMB: netTxMB,
|
||||
)
|
||||
)
|
||||
except:
|
||||
return samples
|
||||
|
||||
proc processDockerStatsLog*(inputPath: string): seq[DockerStatsSample] =
|
||||
var samples: seq[DockerStatsSample]
|
||||
for line in lines(inputPath):
|
||||
let sampleOpt = parseDockerStatsLine(line)
|
||||
if sampleOpt.isSome:
|
||||
samples.add(sampleOpt.get)
|
||||
return samples
|
||||
|
||||
proc calcRateMBps(curr: float, prev: float, dt: float): float =
|
||||
if dt == 0:
|
||||
return 0.0
|
||||
  return (curr - prev) / dt
|
||||
|
||||
proc writeCsvSeries(samples: seq[DockerStatsSample], outPath: string) =
|
||||
var f = open(outPath, fmWrite)
|
||||
f.writeLine(
|
||||
"timestamp,cpu_percent,mem_usage_mb,download_MBps,upload_MBps,download_MB,upload_MB"
|
||||
)
|
||||
if samples.len == 0:
|
||||
f.close()
|
||||
return
|
||||
let timeOffset = samples[0].timestamp
|
||||
let memOffset = samples[0].memUsageMB
|
||||
let rxOffset = samples[0].netRxMB
|
||||
let txOffset = samples[0].netTxMB
|
||||
var prevRx = samples[0].netRxMB
|
||||
var prevTx = samples[0].netTxMB
|
||||
var prevTimestamp = samples[0].timestamp - timeOffset
|
||||
for s in samples:
|
||||
let relTimestamp = s.timestamp - timeOffset
|
||||
let dt = relTimestamp - prevTimestamp
|
||||
let dlMBps = calcRateMBps(s.netRxMB, prevRx, dt)
|
||||
let ulMBps = calcRateMBps(s.netTxMB, prevTx, dt)
|
||||
let dlAcc = s.netRxMB - rxOffset
|
||||
let ulAcc = s.netTxMB - txOffset
|
||||
let memUsage = s.memUsageMB - memOffset
|
||||
f.writeLine(
|
||||
fmt"{relTimestamp:.2f},{s.cpuPercent:.2f},{memUsage:.2f},{dlMBps:.4f},{ulMBps:.4f},{dlAcc:.4f},{ulAcc:.4f}"
|
||||
)
|
||||
prevRx = s.netRxMB
|
||||
prevTx = s.netTxMB
|
||||
prevTimestamp = relTimestamp
|
||||
f.close()
|
||||
|
||||
proc findInputFiles(dir: string, prefix: string): seq[string] =
|
||||
var files: seq[string] = @[]
|
||||
for entry in walkDir(dir):
|
||||
if entry.kind == pcFile and entry.path.endsWith(".log") and
|
||||
entry.path.contains(prefix):
|
||||
files.add(entry.path)
|
||||
return files
|
||||
|
||||
proc main() =
|
||||
let dir = getEnv("SHARED_VOLUME_PATH", "performance/output")
|
||||
let prefix = getEnv("DOCKER_STATS_PREFIX", "docker_stats_")
|
||||
|
||||
let inputFiles = findInputFiles(dir, prefix)
|
||||
if inputFiles.len == 0:
|
||||
echo "No docker stats files found."
|
||||
return
|
||||
|
||||
for inputFile in inputFiles:
|
||||
let processedStats = processDockerStatsLog(inputFile)
|
||||
let outputCsvPath = inputFile.replace(".log", ".csv")
|
||||
|
||||
writeCsvSeries(processedStats, outputCsvPath)
|
||||
echo fmt"Processed stats from {inputFile} written to {outputCsvPath}"
|
||||
|
||||
main()
|
||||
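The CPU figure above follows the usual Docker stats formula: the container's `cpu_usage.total_usage` delta over the host's `system_cpu_usage` delta, scaled by `online_cpus` and 100. A small Python check of that arithmetic with made-up counter values:

```python
def calc_cpu_percent(cpu_delta, system_delta, num_cpus):
    # same guard and formula as calcCpuPercent above
    if system_delta > 0 and cpu_delta > 0 and num_cpus > 0:
        return (cpu_delta / system_delta) * num_cpus * 100.0
    return 0.0

# Container used 0.5s of CPU time while the host accumulated 8s across 4 CPUs:
assert calc_cpu_percent(500_000_000, 8_000_000_000, 4) == 25.0
```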
@@ -4,7 +4,7 @@ import sequtils
|
||||
import strutils
|
||||
import strformat
|
||||
import tables
|
||||
import ./types
|
||||
import ../types
|
||||
|
||||
const unknownFloat = -1.0
|
||||
|
||||
@@ -93,7 +93,8 @@ proc getMarkdownReport*(
|
||||
output.add marker & "\n"
|
||||
output.add "# 🏁 **Performance Summary**\n"
|
||||
|
||||
output.add fmt"**Commit:** `{commitSha}`"
|
||||
let commitUrl = fmt"https://github.com/vacp2p/nim-libp2p/commit/{commitSha}"
|
||||
output.add fmt"**Commit:** [`{commitSha}`]({commitUrl})"
|
||||
|
||||
output.add "| Scenario | Nodes | Total messages sent | Total messages received | Latency min (ms) | Latency max (ms) | Latency avg (ms) |"
|
||||
output.add "|:---:|:---:|:---:|:---:|:---:|:---:|:---:|"
|
||||
@@ -102,10 +103,29 @@ proc getMarkdownReport*(
|
||||
let nodes = validNodes[scenarioName]
|
||||
output.add fmt"| {stats.scenarioName} | {nodes} | {stats.totalSent} | {stats.totalReceived} | {stats.latency.minLatencyMs:.3f} | {stats.latency.maxLatencyMs:.3f} | {stats.latency.avgLatencyMs:.3f} |"
|
||||
|
||||
let markdown = output.join("\n")
|
||||
let runId = getEnv("GITHUB_RUN_ID", "")
|
||||
let summaryUrl = fmt"https://github.com/vacp2p/nim-libp2p/actions/runs/{runId}"
|
||||
output.add(
|
||||
fmt"### 📊 View Latency History and full Container Resources in the [Workflow Summary]({summaryUrl})"
|
||||
)
|
||||
|
||||
let markdown = output.join("\n")
|
||||
return markdown
|
||||
|
||||
proc getCsvFilename*(outputDir: string): string =
|
||||
let prNum = getEnv("PR_NUMBER", "unknown")
|
||||
result = fmt"{outputDir}/pr{prNum}_latency.csv"
|
||||
|
||||
proc getCsvReport*(
|
||||
results: Table[string, Stats], validNodes: Table[string, int]
|
||||
): string =
|
||||
var output: seq[string]
|
||||
output.add "Scenario,Nodes,TotalSent,TotalReceived,MinLatencyMs,MaxLatencyMs,AvgLatencyMs"
|
||||
for scenarioName, stats in results.pairs:
|
||||
let nodes = validNodes[scenarioName]
|
||||
output.add fmt"{stats.scenarioName},{nodes},{stats.totalSent},{stats.totalReceived},{stats.latency.minLatencyMs:.3f},{stats.latency.maxLatencyMs:.3f},{stats.latency.avgLatencyMs:.3f}"
|
||||
result = output.join("\n")
|
||||
|
||||
proc main() =
|
||||
let outputDir = "performance/output"
|
||||
let parsedJsons = parseJsonFiles(outputDir)
|
||||
@@ -113,6 +133,11 @@ proc main() =
|
||||
let jsonResults = getJsonResults(parsedJsons)
|
||||
let (aggregatedResults, validNodes) = aggregateResults(jsonResults)
|
||||
|
||||
# For History
|
||||
let csvFilename = getCsvFilename(outputDir)
|
||||
let csvContent = getCsvReport(aggregatedResults, validNodes)
|
||||
writeFile(csvFilename, csvContent)
|
||||
|
||||
let marker = getEnv("MARKER", "<!-- marker -->")
|
||||
let commitSha = getEnv("PR_HEAD_SHA", getEnv("GITHUB_SHA", "unknown"))
|
||||
let markdown = getMarkdownReport(aggregatedResults, validNodes, marker, commitSha)
|
||||
@@ -266,7 +266,7 @@ proc syncNodes*(stage: string, nodeId, nodeCount: int) {.async.} =
|
||||
writeFile(myFile, "ok")
|
||||
|
||||
let expectedFiles = (0 ..< nodeCount).mapIt(syncDir / (prefix & stage & "_" & $it))
|
||||
checkUntilTimeoutCustom(5.seconds, 100.milliseconds):
|
||||
checkUntilTimeoutCustom(15.seconds, 100.milliseconds):
|
||||
expectedFiles.allIt(fileExists(it))
|
||||
|
||||
# final wait
|
||||
@@ -279,3 +279,38 @@ proc clearSyncFiles*() =
|
||||
for f in walkDir(syncDir):
|
||||
if fileExists(f.path):
|
||||
removeFile(f.path)
|
||||
|
||||
proc getDockerStatsLogPath*(scenarioName: string, nodeId: int): string =
|
||||
let sanitizedScenario = scenarioName.replace(" ", "").replace("%", "percent")
|
||||
return &"/output/docker_stats_{sanitizedScenario}_{nodeId}.log"
|
||||
|
||||
proc clearDockerStats*(outputPath: string) =
|
||||
if fileExists(outputPath):
|
||||
removeFile(outputPath)
|
||||
|
||||
proc getContainerId(nodeId: int): string =
|
||||
let response = execShellCommand(
|
||||
"curl -s --unix-socket /var/run/docker.sock http://localhost/containers/json"
|
||||
)
|
||||
let containers = parseJson(response)
|
||||
let expectedName = "/node-" & $nodeId
|
||||
let filtered =
|
||||
containers.filterIt(it["Names"].getElems(@[]).anyIt(it.getStr("") == expectedName))
|
||||
if filtered.len == 0:
|
||||
return ""
|
||||
return filtered[0]["Id"].getStr("")
|
||||
|
||||
proc startDockerStatsProcess*(nodeId: int, outputPath: string): Process =
|
||||
let containerId = getContainerId(nodeId)
|
||||
|
||||
let shellCmd =
|
||||
fmt"curl --unix-socket /var/run/docker.sock http://localhost/containers/{containerId}/stats > {outputPath} 2>/dev/null"
|
||||
|
||||
return startProcess(
|
||||
"/bin/sh", args = ["-c", shellCmd], options = {poUsePath, poStdErrToStdOut}
|
||||
)
|
||||
|
||||
proc stopDockerStatsProcess*(p: Process) =
|
||||
if p != nil:
|
||||
p.kill()
|
||||
p.close()
|
||||
|
||||
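`getContainerId` and `startDockerStatsProcess` talk to the Docker Engine API over the unix socket via curl. An equivalent Python sketch of the container lookup, assuming the same `/node-<id>` container naming convention used by the test harness:

```python
import json, subprocess

def get_container_id(node_id):
    # list running containers via the Docker Engine API on the unix socket
    out = subprocess.run(
        ["curl", "-s", "--unix-socket", "/var/run/docker.sock",
         "http://localhost/containers/json"],
        capture_output=True, text=True, check=True,
    ).stdout
    expected = f"/node-{node_id}"
    for container in json.loads(out):
        if expected in container.get("Names", []):
            return container["Id"]
    return ""

# The stats stream for that container is then read from
# http://localhost/containers/<id>/stats on the same socket.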
@@ -516,6 +516,7 @@ proc relayInteropTests*(name: string, relayCreator: SwitchCreator) =
|
||||
suite "Interop relay using " & name:
|
||||
asyncTest "NativeSrc -> NativeRelay -> DaemonDst":
|
||||
let closeBlocker = newFuture[void]()
|
||||
let daemonFinished = newFuture[void]()
|
||||
# TODO: This Future blocks the daemonHandler after sending the last message.
|
||||
# It exists because there's a strange behavior where stream.close sends
|
||||
# a Rst instead of Fin. We should investigate this at some point.
|
||||
@@ -528,6 +529,7 @@ proc relayInteropTests*(name: string, relayCreator: SwitchCreator) =
|
||||
discard await stream.transp.writeLp("line4")
|
||||
await closeBlocker
|
||||
await stream.close()
|
||||
daemonFinished.complete()
|
||||
|
||||
let
|
||||
maSrc = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()
|
||||
@@ -556,8 +558,18 @@ proc relayInteropTests*(name: string, relayCreator: SwitchCreator) =
|
||||
check string.fromBytes(await conn.readLp(1024)) == "line4"
|
||||
|
||||
closeBlocker.complete()
|
||||
await daemonFinished
|
||||
await conn.close()
|
||||
await allFutures(src.stop(), rel.stop())
|
||||
await daemonNode.close()
|
||||
try:
|
||||
await daemonNode.close()
|
||||
except CatchableError as e:
|
||||
when defined(windows):
|
||||
# On Windows, daemon close may fail due to socket race condition
|
||||
# This is expected behavior and can be safely ignored
|
||||
discard
|
||||
else:
|
||||
raise e
|
||||
|
||||
asyncTest "DaemonSrc -> NativeRelay -> NativeDst":
|
||||
proc customHandler(
|
||||
|
||||
@@ -32,6 +32,9 @@ template commonTransportTest*(prov: TransportBuilder, ma1: string, ma2: string =
|
||||
let conn = await transport1.accept()
|
||||
if conn.observedAddr.isSome():
|
||||
check transport1.handles(conn.observedAddr.get())
|
||||
# skip IP check, only check transport and port
|
||||
check conn.localAddr.get()[3] == transport1.addrs[0][3]
|
||||
check conn.localAddr.get()[4] == transport1.addrs[0][4]
|
||||
await conn.close()
|
||||
|
||||
let handlerWait = acceptHandler()
|
||||
|
||||
3
tests/discovery/testdiscovery.nim
Normal file
@@ -0,0 +1,3 @@
|
||||
{.used.}
|
||||
|
||||
import testdiscoverymngr, testrendezvous, testrendezvousinterface
|
||||
@@ -11,25 +11,16 @@
|
||||
|
||||
import options, chronos, sets
|
||||
import
|
||||
../libp2p/[
|
||||
../../libp2p/[
|
||||
protocols/rendezvous,
|
||||
switch,
|
||||
builders,
|
||||
discovery/discoverymngr,
|
||||
discovery/rendezvousinterface,
|
||||
]
|
||||
import ./helpers, ./utils/async_tests
|
||||
|
||||
proc createSwitch(rdv: RendezVous = RendezVous.new()): Switch =
|
||||
SwitchBuilder
|
||||
.new()
|
||||
.withRng(newRng())
|
||||
.withAddresses(@[MultiAddress.init(MemoryAutoAddress).tryGet()])
|
||||
.withMemoryTransport()
|
||||
.withMplex()
|
||||
.withNoise()
|
||||
.withRendezVous(rdv)
|
||||
.build()
|
||||
import ../helpers
|
||||
import ../utils/async_tests
|
||||
import ./utils
|
||||
|
||||
suite "Discovery":
|
||||
teardown:
|
||||
471
tests/discovery/testrendezvous.nim
Normal file
@@ -0,0 +1,471 @@
|
||||
{.used.}
|
||||
|
||||
# Nim-Libp2p
|
||||
# Copyright (c) 2023 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import sequtils, strutils
|
||||
import chronos
|
||||
import ../../libp2p/[protocols/rendezvous, switch]
|
||||
import ../../libp2p/discovery/discoverymngr
|
||||
import ../../libp2p/utils/offsettedseq
|
||||
import ../helpers
|
||||
import ./utils
|
||||
|
||||
suite "RendezVous":
|
||||
teardown:
|
||||
checkTrackers()
|
||||
|
||||
asyncTest "Request locally returns 0 for empty namespace":
|
||||
let (nodes, rdvs) = setupNodes(1)
|
||||
nodes.startAndDeferStop()
|
||||
|
||||
const namespace = ""
|
||||
check rdvs[0].requestLocally(namespace).len == 0
|
||||
|
||||
asyncTest "Request locally returns registered peers":
|
||||
let (nodes, rdvs) = setupNodes(1)
|
||||
nodes.startAndDeferStop()
|
||||
|
||||
const namespace = "foo"
|
||||
await rdvs[0].advertise(namespace)
|
||||
let peerRecords = rdvs[0].requestLocally(namespace)
|
||||
|
||||
check:
|
||||
peerRecords.len == 1
|
||||
peerRecords[0] == nodes[0].peerInfo.signedPeerRecord.data
|
||||
|
||||
asyncTest "Unsubscribe Locally removes registered peer":
|
||||
let (nodes, rdvs) = setupNodes(1)
|
||||
nodes.startAndDeferStop()
|
||||
|
||||
const namespace = "foo"
|
||||
await rdvs[0].advertise(namespace)
|
||||
check rdvs[0].requestLocally(namespace).len == 1
|
||||
|
||||
rdvs[0].unsubscribeLocally(namespace)
|
||||
check rdvs[0].requestLocally(namespace).len == 0
|
||||
|
||||
asyncTest "Request returns 0 for empty namespace from remote":
|
||||
let (rendezvousNode, peerNodes, peerRdvs, _) = setupRendezvousNodeWithPeerNodes(1)
|
||||
(rendezvousNode & peerNodes).startAndDeferStop()
|
||||
|
||||
await connectNodes(peerNodes[0], rendezvousNode)
|
||||
|
||||
const namespace = "empty"
|
||||
check (await peerRdvs[0].request(Opt.some(namespace))).len == 0
|
||||
|
||||
asyncTest "Request returns registered peers from remote":
|
||||
let (rendezvousNode, peerNodes, peerRdvs, _) = setupRendezvousNodeWithPeerNodes(1)
|
||||
(rendezvousNode & peerNodes).startAndDeferStop()
|
||||
|
||||
await connectNodes(peerNodes[0], rendezvousNode)
|
||||
|
||||
const namespace = "foo"
|
||||
await peerRdvs[0].advertise(namespace)
|
||||
let peerRecords = await peerRdvs[0].request(Opt.some(namespace))
|
||||
check:
|
||||
peerRecords.len == 1
|
||||
peerRecords[0] == peerNodes[0].peerInfo.signedPeerRecord.data
|
||||
|
||||
asyncTest "Unsubscribe removes registered peer from remote":
|
||||
let (rendezvousNode, peerNodes, peerRdvs, _) = setupRendezvousNodeWithPeerNodes(1)
|
||||
(rendezvousNode & peerNodes).startAndDeferStop()
|
||||
|
||||
await connectNodes(peerNodes[0], rendezvousNode)
|
||||
|
||||
const namespace = "foo"
|
||||
await peerRdvs[0].advertise(namespace)
|
||||
|
||||
check (await peerRdvs[0].request(Opt.some(namespace))).len == 1
|
||||
|
||||
await peerRdvs[0].unsubscribe(namespace)
|
||||
check (await peerRdvs[0].request(Opt.some(namespace))).len == 0
|
||||
|
||||
asyncTest "Consecutive requests with namespace returns peers with pagination":
|
||||
let (rendezvousNode, peerNodes, peerRdvs, _) = setupRendezvousNodeWithPeerNodes(11)
|
||||
(rendezvousNode & peerNodes).startAndDeferStop()
|
||||
|
||||
await connectNodesToRendezvousNode(peerNodes, rendezvousNode)
|
||||
|
||||
const namespace = "foo"
|
||||
await allFutures(peerRdvs.mapIt(it.advertise(namespace)))
|
||||
|
||||
var data = peerNodes.mapIt(it.peerInfo.signedPeerRecord.data)
|
||||
var peerRecords = await peerRdvs[0].request(Opt.some(namespace), 5)
|
||||
check:
|
||||
peerRecords.len == 5
|
||||
peerRecords.allIt(it in data)
|
||||
data.keepItIf(it notin peerRecords)
|
||||
|
||||
peerRecords = await peerRdvs[0].request(Opt.some(namespace))
|
||||
check:
|
||||
peerRecords.len == 6
|
||||
peerRecords.allIt(it in data)
|
||||
|
||||
check (await peerRdvs[0].request(Opt.some(namespace))).len == 0
|
||||
|
||||
asyncTest "Request without namespace returns all registered peers":
|
||||
let (rendezvousNode, peerNodes, peerRdvs, _) = setupRendezvousNodeWithPeerNodes(10)
|
||||
(rendezvousNode & peerNodes).startAndDeferStop()
|
||||
|
||||
await connectNodesToRendezvousNode(peerNodes, rendezvousNode)
|
||||
|
||||
const namespaceFoo = "foo"
|
||||
const namespaceBar = "Bar"
|
||||
await allFutures(peerRdvs[0 ..< 5].mapIt(it.advertise(namespaceFoo)))
|
||||
await allFutures(peerRdvs[5 ..< 10].mapIt(it.advertise(namespaceBar)))
|
||||
|
||||
check (await peerRdvs[0].request()).len == 10
|
||||
|
||||
check (await peerRdvs[0].request(Opt.none(string))).len == 10
|
||||
|
||||
asyncTest "Consecutive requests with namespace keep cookie and retun only new peers":
|
||||
let (rendezvousNode, peerNodes, peerRdvs, _) = setupRendezvousNodeWithPeerNodes(2)
|
||||
(rendezvousNode & peerNodes).startAndDeferStop()
|
||||
|
||||
await connectNodesToRendezvousNode(peerNodes, rendezvousNode)
|
||||
|
||||
let
|
||||
rdv0 = peerRdvs[0]
|
||||
rdv1 = peerRdvs[1]
|
||||
const namespace = "foo"
|
||||
|
||||
await rdv0.advertise(namespace)
|
||||
discard await rdv0.request(Opt.some(namespace))
|
||||
|
||||
await rdv1.advertise(namespace)
|
||||
let peerRecords = await rdv0.request(Opt.some(namespace))
|
||||
|
||||
check:
|
||||
peerRecords.len == 1
|
||||
peerRecords[0] == peerNodes[1].peerInfo.signedPeerRecord.data
|
||||
|
||||
asyncTest "Request with namespace pagination with multiple namespaces":
|
||||
let (rendezvousNode, peerNodes, peerRdvs, _) = setupRendezvousNodeWithPeerNodes(30)
|
||||
(rendezvousNode & peerNodes).startAndDeferStop()
|
||||
|
||||
await connectNodesToRendezvousNode(peerNodes, rendezvousNode)
|
||||
|
||||
let rdv = peerRdvs[0]
|
||||
|
||||
# Register peers in different namespaces in mixed order
|
||||
const
|
||||
namespaceFoo = "foo"
|
||||
namespaceBar = "bar"
|
||||
await allFutures(peerRdvs[0 ..< 5].mapIt(it.advertise(namespaceFoo)))
|
||||
await allFutures(peerRdvs[5 ..< 10].mapIt(it.advertise(namespaceBar)))
|
||||
await allFutures(peerRdvs[10 ..< 15].mapIt(it.advertise(namespaceFoo)))
|
||||
await allFutures(peerRdvs[15 ..< 20].mapIt(it.advertise(namespaceBar)))
|
||||
|
||||
var fooRecords = peerNodes[0 ..< 5].concat(peerNodes[10 ..< 15]).mapIt(
|
||||
it.peerInfo.signedPeerRecord.data
|
||||
)
|
||||
var barRecords = peerNodes[5 ..< 10].concat(peerNodes[15 ..< 20]).mapIt(
|
||||
it.peerInfo.signedPeerRecord.data
|
||||
)
|
||||
|
||||
# Foo Page 1 with limit
|
||||
var peerRecords = await rdv.request(Opt.some(namespaceFoo), 2)
|
||||
check:
|
||||
peerRecords.len == 2
|
||||
peerRecords.allIt(it in fooRecords)
|
||||
fooRecords.keepItIf(it notin peerRecords)
|
||||
|
||||
# Foo Page 2 with limit
|
||||
peerRecords = await rdv.request(Opt.some(namespaceFoo), 5)
|
||||
check:
|
||||
peerRecords.len == 5
|
||||
peerRecords.allIt(it in fooRecords)
|
||||
fooRecords.keepItIf(it notin peerRecords)
|
||||
|
||||
# Foo Page 3 with the rest
|
||||
peerRecords = await rdv.request(Opt.some(namespaceFoo))
|
||||
check:
|
||||
peerRecords.len == 3
|
||||
peerRecords.allIt(it in fooRecords)
|
||||
fooRecords.keepItIf(it notin peerRecords)
|
||||
check fooRecords.len == 0
|
||||
|
||||
# Foo Page 4 empty
|
||||
peerRecords = await rdv.request(Opt.some(namespaceFoo))
|
||||
check peerRecords.len == 0
|
||||
|
||||
# Bar Page 1 with all
|
||||
peerRecords = await rdv.request(Opt.some(namespaceBar), 30)
|
||||
check:
|
||||
peerRecords.len == 10
|
||||
peerRecords.allIt(it in barRecords)
|
||||
barRecords.keepItIf(it notin peerRecords)
|
||||
check barRecords.len == 0
|
||||
|
||||
# Register new peers
|
||||
await allFutures(peerRdvs[20 ..< 25].mapIt(it.advertise(namespaceFoo)))
|
||||
await allFutures(peerRdvs[25 ..< 30].mapIt(it.advertise(namespaceBar)))
|
||||
|
||||
# Foo Page 5 only new peers
|
||||
peerRecords = await rdv.request(Opt.some(namespaceFoo))
|
||||
check:
|
||||
peerRecords.len == 5
|
||||
peerRecords.allIt(
|
||||
it in peerNodes[20 ..< 25].mapIt(it.peerInfo.signedPeerRecord.data)
|
||||
)
|
||||
|
||||
# Bar Page 2 only new peers
|
||||
peerRecords = await rdv.request(Opt.some(namespaceBar))
|
||||
check:
|
||||
peerRecords.len == 5
|
||||
peerRecords.allIt(
|
||||
it in peerNodes[25 ..< 30].mapIt(it.peerInfo.signedPeerRecord.data)
|
||||
)
|
||||
|
||||
# All records
|
||||
peerRecords = await rdv.request(Opt.none(string))
|
||||
check peerRecords.len == 30
|
||||
|
||||
asyncTest "Request with namespace with expired peers":
|
||||
let (rendezvousNode, peerNodes, peerRdvs, rdv) =
|
||||
setupRendezvousNodeWithPeerNodes(20)
|
||||
(rendezvousNode & peerNodes).startAndDeferStop()
|
||||
|
||||
await connectNodesToRendezvousNode(peerNodes, rendezvousNode)
|
||||
|
||||
# Advertise peers
|
||||
const
|
||||
namespaceFoo = "foo"
|
||||
namespaceBar = "bar"
|
||||
await allFutures(peerRdvs[0 ..< 5].mapIt(it.advertise(namespaceFoo)))
|
||||
await allFutures(peerRdvs[5 ..< 10].mapIt(it.advertise(namespaceBar)))
|
||||
|
||||
check:
|
||||
(await peerRdvs[0].request(Opt.some(namespaceFoo))).len == 5
|
||||
(await peerRdvs[0].request(Opt.some(namespaceBar))).len == 5
|
||||
|
||||
# Overwrite register timeout loop interval
|
||||
discard rdv.deletesRegister(1.seconds)
|
||||
|
||||
# Overwrite expiration times
|
||||
let now = Moment.now()
|
||||
for reg in rdv.registered.s.mitems:
|
||||
reg.expiration = now
|
||||
|
||||
# Wait for the deletion
|
||||
checkUntilTimeout:
|
||||
rdv.registered.offset == 10
|
||||
rdv.registered.s.len == 0
|
||||
(await peerRdvs[0].request(Opt.some(namespaceFoo))).len == 0
|
||||
(await peerRdvs[0].request(Opt.some(namespaceBar))).len == 0
|
||||
|
||||
# Advertise new peers
|
||||
await allFutures(peerRdvs[10 ..< 15].mapIt(it.advertise(namespaceFoo)))
|
||||
await allFutures(peerRdvs[15 ..< 20].mapIt(it.advertise(namespaceBar)))
|
||||
|
||||
check:
|
||||
rdv.registered.offset == 10
|
||||
rdv.registered.s.len == 10
|
||||
(await peerRdvs[0].request(Opt.some(namespaceFoo))).len == 5
|
||||
(await peerRdvs[0].request(Opt.some(namespaceBar))).len == 5
|
||||
|
||||
asyncTest "Cookie offset is reset to end (returns empty) then new peers are discoverable":
|
||||
let (rendezvousNode, peerNodes, peerRdvs, rdv) = setupRendezvousNodeWithPeerNodes(3)
|
||||
(rendezvousNode & peerNodes).startAndDeferStop()
|
||||
|
||||
await connectNodesToRendezvousNode(peerNodes, rendezvousNode)
|
||||
|
||||
const namespace = "foo"
|
||||
# Advertise two peers initially
|
||||
await allFutures(peerRdvs[0 ..< 2].mapIt(it.advertise(namespace)))
|
||||
|
||||
# Build and inject overflow cookie: offset past current high()+1
|
||||
let offset = (rdv.registered.high + 1000).uint64
|
||||
let cookie = buildProtobufCookie(offset, namespace)
|
||||
peerRdvs[0].injectCookieForPeer(rendezvousNode.peerInfo.peerId, namespace, cookie)
|
||||
|
||||
# First request should return empty due to clamping to high()+1
|
||||
check (await peerRdvs[0].request(Opt.some(namespace))).len == 0
|
||||
|
||||
# Advertise a new peer, next request should return only the new one
|
||||
await peerRdvs[2].advertise(namespace)
|
||||
let peerRecords = await peerRdvs[0].request(Opt.some(namespace))
|
||||
check:
|
||||
peerRecords.len == 1
|
||||
peerRecords[0] == peerNodes[2].peerInfo.signedPeerRecord.data
|
||||
|
||||
asyncTest "Cookie offset is reset to low after flush (returns current entries)":
|
||||
let (rendezvousNode, peerNodes, peerRdvs, rdv) = setupRendezvousNodeWithPeerNodes(8)
|
||||
(rendezvousNode & peerNodes).startAndDeferStop()
|
||||
|
||||
await connectNodesToRendezvousNode(peerNodes, rendezvousNode)
|
||||
|
||||
const namespace = "foo"
|
||||
# Advertise 4 peers in namespace
|
||||
await allFutures(peerRdvs[0 ..< 4].mapIt(it.advertise(namespace)))
|
||||
|
||||
# Expire all and flush to advance registered.offset
|
||||
discard rdv.deletesRegister(1.seconds)
|
||||
let now = Moment.now()
|
||||
for reg in rdv.registered.s.mitems:
|
||||
reg.expiration = now
|
||||
|
||||
checkUntilTimeout:
|
||||
rdv.registered.s.len == 0
|
||||
rdv.registered.offset == 4
|
||||
|
||||
# Advertise 4 new peers
|
||||
await allFutures(peerRdvs[4 ..< 8].mapIt(it.advertise(namespace)))
|
||||
|
||||
# Build and inject underflow cookie: offset behind current low
|
||||
let offset = 0'u64
|
||||
let cookie = buildProtobufCookie(offset, namespace)
|
||||
peerRdvs[0].injectCookieForPeer(rendezvousNode.peerInfo.peerId, namespace, cookie)
|
||||
|
||||
check (await peerRdvs[0].request(Opt.some(namespace))).len == 4
|
||||
|
||||
asyncTest "Cookie namespace mismatch resets to low (returns peers despite offset)":
|
||||
let (rendezvousNode, peerNodes, peerRdvs, rdv) = setupRendezvousNodeWithPeerNodes(3)
|
||||
(rendezvousNode & peerNodes).startAndDeferStop()
|
||||
|
||||
await connectNodesToRendezvousNode(peerNodes, rendezvousNode)
|
||||
|
||||
const namespace = "foo"
|
||||
await allFutures(peerRdvs.mapIt(it.advertise(namespace)))
|
||||
|
||||
# Build and inject cookie with wrong namespace
|
||||
let offset = 10.uint64
|
||||
let cookie = buildProtobufCookie(offset, "other")
|
||||
peerRdvs[0].injectCookieForPeer(rendezvousNode.peerInfo.peerId, namespace, cookie)
|
||||
|
||||
check (await peerRdvs[0].request(Opt.some(namespace))).len == 3
|
||||
|
||||
asyncTest "Peer default TTL is saved when advertised":
|
||||
let (rendezvousNode, peerNodes, peerRdvs, rendezvousRdv) =
|
||||
setupRendezvousNodeWithPeerNodes(1)
|
||||
(rendezvousNode & peerNodes).startAndDeferStop()
|
||||
|
||||
await connectNodes(peerNodes[0], rendezvousNode)
|
||||
|
||||
const namespace = "foo"
|
||||
let timeBefore = Moment.now()
|
||||
await peerRdvs[0].advertise(namespace)
|
||||
let timeAfter = Moment.now()
|
||||
|
||||
# expiration within [timeBefore + 2hours, timeAfter + 2hours]
|
||||
check:
|
||||
# Peer Node side
|
||||
peerRdvs[0].registered.s[0].data.ttl.get == MinimumDuration.seconds.uint64
|
||||
peerRdvs[0].registered.s[0].expiration >= timeBefore + MinimumDuration
|
||||
peerRdvs[0].registered.s[0].expiration <= timeAfter + MinimumDuration
|
||||
# Rendezvous Node side
|
||||
rendezvousRdv.registered.s[0].data.ttl.get == MinimumDuration.seconds.uint64
|
||||
rendezvousRdv.registered.s[0].expiration >= timeBefore + MinimumDuration
|
||||
rendezvousRdv.registered.s[0].expiration <= timeAfter + MinimumDuration
|
||||
|
||||
asyncTest "Peer TTL is saved when advertised with TTL":
|
||||
let (rendezvousNode, peerNodes, peerRdvs, rendezvousRdv) =
|
||||
setupRendezvousNodeWithPeerNodes(1)
|
||||
(rendezvousNode & peerNodes).startAndDeferStop()
|
||||
|
||||
await connectNodes(peerNodes[0], rendezvousNode)
|
||||
|
||||
const
|
||||
namespace = "foo"
|
||||
ttl = 3.hours
|
||||
let timeBefore = Moment.now()
|
||||
await peerRdvs[0].advertise(namespace, ttl)
|
||||
let timeAfter = Moment.now()
|
||||
|
||||
# expiration within [timeBefore + ttl, timeAfter + ttl]
|
||||
check:
|
||||
# Peer Node side
|
||||
peerRdvs[0].registered.s[0].data.ttl.get == ttl.seconds.uint64
|
||||
peerRdvs[0].registered.s[0].expiration >= timeBefore + ttl
|
||||
peerRdvs[0].registered.s[0].expiration <= timeAfter + ttl
|
||||
# Rendezvous Node side
|
||||
rendezvousRdv.registered.s[0].data.ttl.get == ttl.seconds.uint64
|
||||
rendezvousRdv.registered.s[0].expiration >= timeBefore + ttl
|
||||
rendezvousRdv.registered.s[0].expiration <= timeAfter + ttl
|
||||
|
||||
asyncTest "Peer can reregister to update its TTL before previous TTL expires":
|
||||
let (rendezvousNode, peerNodes, peerRdvs, rendezvousRdv) =
|
||||
setupRendezvousNodeWithPeerNodes(1)
|
||||
(rendezvousNode & peerNodes).startAndDeferStop()
|
||||
|
||||
await connectNodes(peerNodes[0], rendezvousNode)
|
||||
|
||||
const namespace = "foo"
|
||||
let now = Moment.now()
|
||||
|
||||
await peerRdvs[0].advertise(namespace)
|
||||
check:
|
||||
# Peer Node side
|
||||
peerRdvs[0].registered.s.len == 1
|
||||
peerRdvs[0].registered.s[0].expiration > now
|
||||
# Rendezvous Node side
|
||||
rendezvousRdv.registered.s.len == 1
|
||||
rendezvousRdv.registered.s[0].expiration > now
|
||||
|
||||
await peerRdvs[0].advertise(namespace, 5.hours)
|
||||
check:
|
||||
# Added 2nd registration
|
||||
# Updated expiration of the 1st one to the past
|
||||
# Will be deleted on deletion heartbeat
|
||||
# Peer Node side
|
||||
peerRdvs[0].registered.s.len == 2
|
||||
peerRdvs[0].registered.s[0].expiration < now
|
||||
# Rendezvous Node side
|
||||
rendezvousRdv.registered.s.len == 2
|
||||
rendezvousRdv.registered.s[0].expiration < now
|
||||
|
||||
# Returns only one record
|
||||
check (await peerRdvs[0].request(Opt.some(namespace))).len == 1
|
||||
|
||||
asyncTest "Peer registration is ignored if limit of 1000 registrations is reached":
|
||||
let (rendezvousNode, peerNodes, peerRdvs, rendezvousRdv) =
|
||||
setupRendezvousNodeWithPeerNodes(1)
|
||||
(rendezvousNode & peerNodes).startAndDeferStop()
|
||||
|
||||
await connectNodes(peerNodes[0], rendezvousNode)
|
||||
|
||||
const namespace = "foo"
|
||||
let peerRdv = peerRdvs[0]
|
||||
|
||||
# Create 999 registrations
|
||||
await populatePeerRegistrations(
|
||||
peerRdv, rendezvousRdv, namespace, RegistrationLimitPerPeer - 1
|
||||
)
|
||||
|
||||
# 1000th registration allowed
|
||||
await peerRdv.advertise(namespace)
|
||||
check rendezvousRdv.registered.s.len == RegistrationLimitPerPeer
|
||||
|
||||
# 1001st registration ignored, limit reached
|
||||
await peerRdv.advertise(namespace)
|
||||
check rendezvousRdv.registered.s.len == RegistrationLimitPerPeer
|
||||
|
||||
asyncTest "Various local error":
|
||||
let rdv = RendezVous.new(minDuration = 1.minutes, maxDuration = 72.hours)
|
||||
expect AdvertiseError:
|
||||
discard await rdv.request(Opt.some("A".repeat(300)))
|
||||
expect AdvertiseError:
|
||||
discard await rdv.request(Opt.some("A"), -1)
|
||||
expect AdvertiseError:
|
||||
discard await rdv.request(Opt.some("A"), 3000)
|
||||
expect AdvertiseError:
|
||||
await rdv.advertise("A".repeat(300))
|
||||
expect AdvertiseError:
|
||||
await rdv.advertise("A", 73.hours)
|
||||
expect AdvertiseError:
|
||||
await rdv.advertise("A", 30.seconds)
|
||||
|
||||
test "Various config error":
|
||||
expect RendezVousError:
|
||||
discard RendezVous.new(minDuration = 30.seconds)
|
||||
expect RendezVousError:
|
||||
discard RendezVous.new(maxDuration = 73.hours)
|
||||
expect RendezVousError:
|
||||
discard RendezVous.new(minDuration = 15.minutes, maxDuration = 10.minutes)
|
||||
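The cookie tests above pin down how the server treats the offset carried in the cookie: a value past `registered.high` is clamped to `high + 1` (requests return nothing until new peers register), while a value below `registered.low` or a namespace mismatch falls back to `low` (the full current window is returned). A toy Python model of the behaviour these tests assert, purely illustrative and not the implementation:

```python
def effective_offset(cookie_offset, cookie_ns, request_ns, low, high):
    """Offset the server will actually read registrations from."""
    if cookie_ns != request_ns or cookie_offset < low:
        return low              # mismatch or underflow: restart from the oldest entry
    if cookie_offset > high + 1:
        return high + 1         # overflow: nothing until new registrations arrive
    return cookie_offset

assert effective_offset(1000, "foo", "foo", low=0, high=2) == 3   # overflow clamps to high+1
assert effective_offset(0, "foo", "foo", low=4, high=7) == 4      # underflow resets to low
assert effective_offset(10, "other", "foo", low=0, high=2) == 0   # namespace mismatch resets
```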
@@ -9,22 +9,11 @@
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import sequtils, strutils
|
||||
import chronos
|
||||
import ../libp2p/[protocols/rendezvous, switch, builders]
|
||||
import ../libp2p/discovery/[rendezvousinterface, discoverymngr]
|
||||
import ./helpers
|
||||
|
||||
proc createSwitch(rdv: RendezVous = RendezVous.new()): Switch =
|
||||
SwitchBuilder
|
||||
.new()
|
||||
.withRng(newRng())
|
||||
.withAddresses(@[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()])
|
||||
.withTcpTransport()
|
||||
.withMplex()
|
||||
.withNoise()
|
||||
.withRendezVous(rdv)
|
||||
.build()
|
||||
import ../../libp2p/[protocols/rendezvous, switch, builders]
|
||||
import ../../libp2p/discovery/[rendezvousinterface, discoverymngr]
|
||||
import ../helpers
|
||||
import ./utils
|
||||
|
||||
type
|
||||
MockRendezVous = ref object of RendezVous
|
||||
@@ -33,7 +22,9 @@ type
|
||||
|
||||
MockErrorRendezVous = ref object of MockRendezVous
|
||||
|
||||
method advertise*(self: MockRendezVous, namespace: string, ttl: Duration) {.async.} =
|
||||
method advertise*(
|
||||
self: MockRendezVous, namespace: string, ttl: Duration
|
||||
) {.async: (raises: [CancelledError, AdvertiseError]).} =
|
||||
if namespace == "ns1":
|
||||
self.numAdvertiseNs1 += 1
|
||||
elif namespace == "ns2":
|
||||
@@ -43,9 +34,9 @@ method advertise*(self: MockRendezVous, namespace: string, ttl: Duration) {.asyn
|
||||
|
||||
method advertise*(
|
||||
self: MockErrorRendezVous, namespace: string, ttl: Duration
|
||||
) {.async.} =
|
||||
) {.async: (raises: [CancelledError, AdvertiseError]).} =
|
||||
await procCall MockRendezVous(self).advertise(namespace, ttl)
|
||||
raise newException(CatchableError, "MockErrorRendezVous.advertise")
|
||||
raise newException(AdvertiseError, "MockErrorRendezVous.advertise")
|
||||
|
||||
suite "RendezVous Interface":
|
||||
teardown:
|
||||
83
tests/discovery/utils.nim
Normal file
@@ -0,0 +1,83 @@
|
||||
import sequtils
|
||||
import chronos
|
||||
import ../../libp2p/[protobuf/minprotobuf, protocols/rendezvous, switch, builders]
|
||||
|
||||
proc createSwitch*(rdv: RendezVous = RendezVous.new()): Switch =
|
||||
SwitchBuilder
|
||||
.new()
|
||||
.withRng(newRng())
|
||||
.withAddresses(@[MultiAddress.init(MemoryAutoAddress).tryGet()])
|
||||
.withMemoryTransport()
|
||||
.withMplex()
|
||||
.withNoise()
|
||||
.withRendezVous(rdv)
|
||||
.build()
|
||||
|
||||
proc setupNodes*(count: int): (seq[Switch], seq[RendezVous]) =
|
||||
doAssert(count > 0, "Count must be greater than 0")
|
||||
|
||||
var
|
||||
nodes: seq[Switch] = @[]
|
||||
rdvs: seq[RendezVous] = @[]
|
||||
|
||||
for x in 0 ..< count:
|
||||
let rdv = RendezVous.new()
|
||||
let node = createSwitch(rdv)
|
||||
nodes.add(node)
|
||||
rdvs.add(rdv)
|
||||
|
||||
return (nodes, rdvs)
|
||||
|
||||
proc setupRendezvousNodeWithPeerNodes*(
|
||||
count: int
|
||||
): (Switch, seq[Switch], seq[RendezVous], RendezVous) =
|
||||
let
|
||||
(nodes, rdvs) = setupNodes(count + 1)
|
||||
rendezvousNode = nodes[0]
|
||||
rendezvousRdv = rdvs[0]
|
||||
peerNodes = nodes[1 ..^ 1]
|
||||
peerRdvs = rdvs[1 ..^ 1]
|
||||
|
||||
return (rendezvousNode, peerNodes, peerRdvs, rendezvousRdv)
|
||||
|
||||
template startAndDeferStop*(nodes: seq[Switch]) =
|
||||
await allFutures(nodes.mapIt(it.start()))
|
||||
defer:
|
||||
await allFutures(nodes.mapIt(it.stop()))
|
||||
|
||||
proc connectNodes*[T: Switch](dialer: T, target: T) {.async.} =
|
||||
await dialer.connect(target.peerInfo.peerId, target.peerInfo.addrs)
|
||||
|
||||
proc connectNodesToRendezvousNode*[T: Switch](
|
||||
nodes: seq[T], rendezvousNode: T
|
||||
) {.async.} =
|
||||
for node in nodes:
|
||||
await connectNodes(node, rendezvousNode)
|
||||
|
||||
proc buildProtobufCookie*(offset: uint64, namespace: string): seq[byte] =
|
||||
var pb = initProtoBuffer()
|
||||
pb.write(1, offset)
|
||||
pb.write(2, namespace)
|
||||
pb.finish()
|
||||
pb.buffer
|
||||
|
||||
proc injectCookieForPeer*(
|
||||
rdv: RendezVous, peerId: PeerId, namespace: string, cookie: seq[byte]
|
||||
) =
|
||||
discard rdv.cookiesSaved.hasKeyOrPut(peerId, {namespace: cookie}.toTable())
|
||||
|
||||
proc populatePeerRegistrations*(
|
||||
peerRdv: RendezVous, targetRdv: RendezVous, namespace: string, count: int
|
||||
) {.async.} =
|
||||
# Test helper: quickly populate many registrations for a peer.
|
||||
# We first create a single real registration, then clone that record
|
||||
# directly into the rendezvous registry to reach the desired count fast.
|
||||
#
|
||||
# Notes:
|
||||
  # - Calling advertise() concurrently results in a bufferstream defect.
|
||||
# - Calling advertise() sequentially is too slow for large counts.
|
||||
await peerRdv.advertise(namespace)
|
||||
|
||||
let record = targetRdv.registered.s[0]
|
||||
for i in 0 ..< count - 1:
|
||||
targetRdv.registered.s.add(record)
|
||||
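`buildProtobufCookie` writes field 1 as a varint offset and field 2 as a length-delimited namespace. Assuming standard protobuf wire encoding (which minprotobuf follows), the same bytes can be reproduced by hand in Python, which is handy for cross-checking test fixtures:

```python
def encode_varint(n):
    out = bytearray()
    while True:
        b = n & 0x7F
        n >>= 7
        out.append(b | 0x80 if n else b)
        if not n:
            return bytes(out)

def build_cookie(offset, namespace):
    ns = namespace.encode()
    return (b"\x08" + encode_varint(offset)               # field 1, varint
            + b"\x12" + encode_varint(len(ns)) + ns)      # field 2, length-delimited

assert build_cookie(5, "foo") == b"\x08\x05\x12\x03foo"
```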
@@ -4,10 +4,12 @@ import std/enumerate
|
||||
import chronos
|
||||
import ../../libp2p/[switch, builders]
|
||||
import ../../libp2p/protocols/kademlia
|
||||
import ../../libp2p/protocols/kademlia/kademlia
|
||||
import ../../libp2p/protocols/kademlia/routingtable
|
||||
import ../../libp2p/protocols/kademlia/keys
|
||||
import unittest2
|
||||
import ../utils/async_tests
|
||||
import ./utils.nim
|
||||
import ../helpers
|
||||
|
||||
proc createSwitch(): Switch =
|
||||
@@ -31,14 +33,15 @@ proc countBucketEntries(buckets: seq[Bucket], key: Key): uint32 =
|
||||
suite "KadDHT - FindNode":
|
||||
teardown:
|
||||
checkTrackers()
|
||||
asyncTest "Simple find peer":
|
||||
|
||||
asyncTest "Simple find node":
|
||||
let swarmSize = 3
|
||||
var switches: seq[Switch]
|
||||
var kads: seq[KadDHT]
|
||||
  # every node needs a switch, and an associated kad mounted to it
|
||||
for i in 0 ..< swarmSize:
|
||||
switches.add(createSwitch())
|
||||
kads.add(KadDHT.new(switches[i]))
|
||||
kads.add(KadDHT.new(switches[i], PermissiveValidator(), CandSelector()))
|
||||
switches[i].mount(kads[i])
|
||||
|
||||
  # Once the creation/mounting of switches is done, we can start
|
||||
@@ -79,24 +82,24 @@ suite "KadDHT - FindNode":
|
||||
)
|
||||
await switches.mapIt(it.stop()).allFutures()
|
||||
|
||||
asyncTest "Relay find peer":
|
||||
asyncTest "Relay find node":
|
||||
let parentSwitch = createSwitch()
|
||||
let parentKad = KadDHT.new(parentSwitch)
|
||||
let parentKad = KadDHT.new(parentSwitch, PermissiveValidator(), CandSelector())
|
||||
parentSwitch.mount(parentKad)
|
||||
await parentSwitch.start()
|
||||
|
||||
let broSwitch = createSwitch()
|
||||
let broKad = KadDHT.new(broSwitch)
|
||||
let broKad = KadDHT.new(broSwitch, PermissiveValidator(), CandSelector())
|
||||
broSwitch.mount(broKad)
|
||||
await broSwitch.start()
|
||||
|
||||
let sisSwitch = createSwitch()
|
||||
let sisKad = KadDHT.new(sisSwitch)
|
||||
let sisKad = KadDHT.new(sisSwitch, PermissiveValidator(), CandSelector())
|
||||
sisSwitch.mount(sisKad)
|
||||
await sisSwitch.start()
|
||||
|
||||
let neiceSwitch = createSwitch()
|
||||
let neiceKad = KadDHT.new(neiceSwitch)
|
||||
let neiceKad = KadDHT.new(neiceSwitch, PermissiveValidator(), CandSelector())
|
||||
neiceSwitch.mount(neiceKad)
|
||||
await neiceSwitch.start()
|
||||
|
||||
@@ -142,3 +145,33 @@ suite "KadDHT - FindNode":
|
||||
await broSwitch.stop()
|
||||
await sisSwitch.stop()
|
||||
await neiceSwitch.stop()
|
||||
|
||||
asyncTest "Find peer":
|
||||
let aliceSwitch = createSwitch()
|
||||
let aliceKad = KadDHT.new(aliceSwitch, PermissiveValidator(), CandSelector())
|
||||
aliceSwitch.mount(aliceKad)
|
||||
await aliceSwitch.start()
|
||||
|
||||
let bobSwitch = createSwitch()
|
||||
let bobKad = KadDHT.new(bobSwitch, PermissiveValidator(), CandSelector())
|
||||
bobSwitch.mount(bobKad)
|
||||
await bobSwitch.start()
|
||||
|
||||
let charlieSwitch = createSwitch()
|
||||
let charlieKad = KadDHT.new(charlieSwitch, PermissiveValidator(), CandSelector())
|
||||
charlieSwitch.mount(charlieKad)
|
||||
await charlieSwitch.start()
|
||||
|
||||
await bobKad.bootstrap(@[aliceSwitch.peerInfo])
|
||||
await charlieKad.bootstrap(@[aliceSwitch.peerInfo])
|
||||
|
||||
let peerInfoRes = await bobKad.findPeer(charlieSwitch.peerInfo.peerId)
|
||||
doAssert peerInfoRes.isOk
|
||||
doAssert peerInfoRes.get().peerId == charlieSwitch.peerInfo.peerId
|
||||
|
||||
let peerInfoRes2 = await bobKad.findPeer(PeerId.random(newRng()).get())
|
||||
doAssert peerInfoRes2.isErr
|
||||
|
||||
await aliceSwitch.stop()
|
||||
await bobSwitch.stop()
|
||||
await charlieSwitch.stop()
|
||||
|
||||
155
tests/kademlia/testputval.nim
Normal file
@@ -0,0 +1,155 @@
|
||||
{.used.}
|
||||
import chronicles
|
||||
import strformat
|
||||
# import sequtils
|
||||
import options
|
||||
import std/[times]
|
||||
# import std/enumerate
|
||||
import chronos
|
||||
import ../../libp2p/[switch, builders]
|
||||
import ../../libp2p/protocols/kademlia
|
||||
import ../../libp2p/protocols/kademlia/kademlia
|
||||
import ../../libp2p/protocols/kademlia/routingtable
|
||||
import ../../libp2p/protocols/kademlia/keys
|
||||
import unittest2
|
||||
import ../utils/async_tests
|
||||
import ./utils.nim
|
||||
import std/tables
|
||||
import ../helpers
|
||||
|
||||
proc createSwitch(): Switch =
|
||||
SwitchBuilder
|
||||
.new()
|
||||
.withRng(newRng())
|
||||
.withAddresses(@[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()])
|
||||
.withTcpTransport()
|
||||
.withMplex()
|
||||
.withNoise()
|
||||
.build()
|
||||
|
||||
proc countBucketEntries(buckets: seq[Bucket], key: Key): uint32 =
|
||||
var res: uint32 = 0
|
||||
for b in buckets:
|
||||
for ent in b.peers:
|
||||
if ent.nodeId == key:
|
||||
res += 1
|
||||
return res
|
||||
|
||||
suite "KadDHT - PutVal":
|
||||
teardown:
|
||||
checkTrackers()
|
||||
asyncTest "Simple put":
|
||||
let switch1 = createSwitch()
|
||||
let switch2 = createSwitch()
|
||||
var kad1 = KadDHT.new(switch1, PermissiveValidator(), CandSelector())
var kad2 = KadDHT.new(switch2, PermissiveValidator(), CandSelector())
switch1.mount(kad1)
switch2.mount(kad2)

await allFutures(switch1.start(), switch2.start())
defer:
await allFutures(switch1.stop(), switch2.stop())

await kad2.bootstrap(@[switch1.peerInfo])

discard await kad1.findNode(kad2.rtable.selfId)
discard await kad2.findNode(kad1.rtable.selfId)

doAssert(len(kad1.dataTable.entries) == 0)
doAssert(len(kad2.dataTable.entries) == 0)
let puttedData = kad1.rtable.selfId.getBytes()
let entryKey = EntryKey.init(puttedData)
let entryVal = EntryValue.init(puttedData)
discard await kad2.putValue(entryKey, entryVal, some(1))

let entered1: EntryValue = kad1.dataTable.entries[entryKey].value
let entered2: EntryValue = kad2.dataTable.entries[entryKey].value

var ents = kad1.dataTable.entries
doAssert(entered1.data == entryVal.data, fmt"table: {ents}, putted: {entryVal}")
doAssert(len(kad1.dataTable.entries) == 1)

ents = kad2.dataTable.entries
doAssert(entered2.data == entryVal.data, fmt"table: {ents}, putted: {entryVal}")
doAssert(len(kad2.dataTable.entries) == 1)

asyncTest "Change Validator":
let switch1 = createSwitch()
let switch2 = createSwitch()
var kad1 = KadDHT.new(switch1, RestrictiveValidator(), CandSelector())
var kad2 = KadDHT.new(switch2, RestrictiveValidator(), CandSelector())
switch1.mount(kad1)
switch2.mount(kad2)

await allFutures(switch1.start(), switch2.start())
defer:
await allFutures(switch1.stop(), switch2.stop())

await kad2.bootstrap(@[switch1.peerInfo])
doAssert(len(kad1.dataTable.entries) == 0)
let puttedData = kad1.rtable.selfId.getBytes()
let entryVal = EntryValue.init(puttedData)
let entryKey = EntryKey.init(puttedData)
discard await kad2.putValue(entryKey, entryVal, some(1))
doAssert(len(kad1.dataTable.entries) == 0, fmt"content: {kad1.dataTable.entries}")
kad1.setValidator(PermissiveValidator())
discard await kad2.putValue(entryKey, entryVal, some(1))

doAssert(len(kad1.dataTable.entries) == 0, fmt"{kad1.dataTable.entries}")
kad2.setValidator(PermissiveValidator())
discard await kad2.putValue(entryKey, entryVal, some(1))
doAssert(len(kad1.dataTable.entries) == 1, fmt"{kad1.dataTable.entries}")

asyncTest "Good Time":
let switch1 = createSwitch()
let switch2 = createSwitch()
var kad1 = KadDHT.new(switch1, PermissiveValidator(), CandSelector())
var kad2 = KadDHT.new(switch2, PermissiveValidator(), CandSelector())
switch1.mount(kad1)
switch2.mount(kad2)
await allFutures(switch1.start(), switch2.start())
defer:
await allFutures(switch1.stop(), switch2.stop())
await kad2.bootstrap(@[switch1.peerInfo])

let puttedData = kad1.rtable.selfId.getBytes()
let entryVal = EntryValue.init(puttedData)
let entryKey = EntryKey.init(puttedData)
discard await kad2.putValue(entryKey, entryVal, some(1))

let time: string = kad1.dataTable.entries[entryKey].time.ts

let now = times.now().utc
let parsed = time.parse(initTimeFormat("yyyy-MM-dd'T'HH:mm:ss'Z'"), utc())

# get the diff between the stringified-parsed and the direct "now"
let elapsed = (now - parsed)
doAssert(elapsed < times.initDuration(seconds = 2))

asyncTest "Reselect":
let switch1 = createSwitch()
let switch2 = createSwitch()
var kad1 = KadDHT.new(switch1, PermissiveValidator(), OthersSelector())
var kad2 = KadDHT.new(switch2, PermissiveValidator(), OthersSelector())
switch1.mount(kad1)
switch2.mount(kad2)
await allFutures(switch1.start(), switch2.start())
defer:
await allFutures(switch1.stop(), switch2.stop())
await kad2.bootstrap(@[switch1.peerInfo])

let puttedData = kad1.rtable.selfId.getBytes()
let entryVal = EntryValue.init(puttedData)
let entryKey = EntryKey.init(puttedData)
discard await kad1.putValue(entryKey, entryVal, some(1))
doAssert(len(kad2.dataTable.entries) == 1, fmt"{kad1.dataTable.entries}")
doAssert(kad2.dataTable.entries[entryKey].value.data == entryVal.data)
discard await kad1.putValue(entryKey, EntryValue.init(@[]), some(1))
doAssert(kad2.dataTable.entries[entryKey].value.data == entryVal.data)
kad2.setSelector(CandSelector())
kad1.setSelector(CandSelector())
discard await kad1.putValue(entryKey, EntryValue.init(@[]), some(1))
doAssert(
kad2.dataTable.entries[entryKey].value == EntryValue.init(@[]),
fmt"{kad2.dataTable.entries}",
)
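
A minimal sketch, not part of this diff, of another validator built on the same EntryValidator/setValidator hook the tests above swap in. It assumes only the types and methods visible in the diff (EntryValidator, EntryKey, EntryValue, isValid, setValidator); the size limit itself is hypothetical.

import ../../libp2p/protocols/kademlia/kademlia

type MaxSizeValidator* = ref object of EntryValidator
  maxLen*: int

method isValid*(self: MaxSizeValidator, key: EntryKey, val: EntryValue): bool =
  # Accept a candidate entry only while its value stays under a configured size.
  val.data.len <= self.maxLen

# Usage mirrors the "Change Validator" test above:
#   kad1.setValidator(MaxSizeValidator(maxLen: 1024))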
@@ -12,23 +12,24 @@
import unittest
import chronos
import ../../libp2p/crypto/crypto
import ../../libp2p/protocols/kademlia/[routingtable, consts, keys]
import ../../libp2p/protocols/kademlia/[xordistance, routingtable, consts, keys]
import results

proc testKey*(x: byte): Key =
var buf: array[IdLength, byte]
buf[31] = x
return Key(kind: KeyType.Unhashed, data: buf)
return Key(kind: KeyType.Raw, data: @buf)

let rng = crypto.newRng()

suite "routing table":
test "inserts single key in correct bucket":
let selfId = testKey(0)
var rt = RoutingTable.init(selfId)
var rt = RoutingTable.init(selfId, Opt.none(XorDHasher))
let other = testKey(0b10000000)
discard rt.insert(other)

let idx = bucketIndex(selfId, other)
let idx = bucketIndex(selfId, other, Opt.none(XorDHasher))
check:
rt.buckets.len > idx
rt.buckets[idx].peers.len == 1
@@ -36,12 +37,11 @@ suite "routing table":

test "does not insert beyond capacity":
let selfId = testKey(0)
var rt = RoutingTable.init(selfId)
var rt = RoutingTable.init(selfId, Opt.some(noOpHasher))
let targetBucket = 6
for _ in 0 ..< DefaultReplic + 5:
var kid = randomKeyInBucketRange(selfId, targetBucket, rng)
kid.kind = KeyType.Unhashed
# Overriding so we don't use sha for comparing xor distances
kid.kind = KeyType.Raw # Overriding so we don't use sha for comparing xor distances
discard rt.insert(kid)

check targetBucket < rt.buckets.len
@@ -50,7 +50,7 @@ suite "routing table":

test "findClosest returns sorted keys":
let selfId = testKey(0)
var rt = RoutingTable.init(selfId)
var rt = RoutingTable.init(selfId, Opt.some(noOpHasher))
let ids = @[testKey(1), testKey(2), testKey(3), testKey(4), testKey(5)]
for id in ids:
discard rt.insert(id)
@@ -75,9 +75,8 @@ suite "routing table":
let selfId = testKey(0)
let targetBucket = 3
var rid = randomKeyInBucketRange(selfId, targetBucket, rng)
rid.kind = KeyType.Unhashed
# Overriding so we don't use sha for comparing xor distances
let idx = bucketIndex(selfId, rid)
rid.kind = KeyType.Raw # Overriding so we don't use sha for comparing xor distances
let idx = bucketIndex(selfId, rid, Opt.some(noOpHasher))
check:
idx == targetBucket
rid != selfId

27  tests/kademlia/utils.nim  Normal file
@@ -0,0 +1,27 @@
{.used.}
import results
import ../../libp2p/protocols/kademlia/kademlia

type PermissiveValidator* = ref object of EntryValidator
method isValid*(self: PermissiveValidator, key: EntryKey, val: EntryValue): bool =
true

type RestrictiveValidator* = ref object of EntryValidator
method isValid(self: RestrictiveValidator, key: EntryKey, val: EntryValue): bool =
false

type CandSelector* = ref object of EntrySelector
method select*(
self: CandSelector, cand: EntryRecord, others: seq[EntryRecord]
): Result[EntryRecord, string] =
return ok(cand)

type OthersSelector* = ref object of EntrySelector
method select*(
self: OthersSelector, cand: EntryRecord, others: seq[EntryRecord]
): Result[EntryRecord, string] =
return
if others.len == 0:
ok(cand)
else:
ok(others[0])
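
A sketch of one more EntrySelector, not part of the diff, complementing the CandSelector/OthersSelector test doubles above: keep whichever record carries the longest value. It assumes EntryRecord exposes the same `value` field the putValue tests read from dataTable entries; that field name is an assumption.

import results
import ../../libp2p/protocols/kademlia/kademlia

type LargestValueSelector* = ref object of EntrySelector

method select*(
    self: LargestValueSelector, cand: EntryRecord, others: seq[EntryRecord]
): Result[EntryRecord, string] =
  # Start from the candidate and replace it with any stored record whose value is longer.
  var best = cand
  for other in others:
    if other.value.data.len > best.value.data.len:
      best = other
  ok(best)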
@@ -243,10 +243,14 @@ suite "GossipSub Integration - Control Messages":
# When an IHAVE message is sent
let p1 = n0.getOrCreatePeer(n1.peerInfo.peerId, @[GossipSubCodec_12])
n0.broadcast(@[p1], RPCMsg(control: some(ihaveMessage)), isHighPriority = false)
await waitForHeartbeat()

# Then the peer has the message ID
# Wait until IHAVE response is received
checkUntilTimeout:
receivedIHaves[].len == 1

# Then the peer has exactly one IHAVE message with the correct message ID
check:
receivedIHaves[].len == 1
receivedIHaves[0] == ControlIHave(topicID: topic, messageIDs: @[messageID])

asyncTest "IWANT messages correctly request messages by their IDs":
@@ -281,10 +285,14 @@ suite "GossipSub Integration - Control Messages":
# When an IWANT message is sent
let p1 = n0.getOrCreatePeer(n1.peerInfo.peerId, @[GossipSubCodec_12])
n0.broadcast(@[p1], RPCMsg(control: some(iwantMessage)), isHighPriority = false)
await waitForHeartbeat()

# Then the peer has the message ID
# Wait until IWANT response is received
checkUntilTimeout:
receivedIWants[].len == 1

# Then the peer has exactly one IWANT message with the correct message ID
check:
receivedIWants[].len == 1
receivedIWants[0] == ControlIWant(messageIDs: @[messageID])

asyncTest "IHAVE for message not held by peer triggers IWANT response to sender":
@@ -316,10 +324,14 @@ suite "GossipSub Integration - Control Messages":
# When an IHAVE message is sent from node0
let p1 = n0.getOrCreatePeer(n1.peerInfo.peerId, @[GossipSubCodec_12])
n0.broadcast(@[p1], RPCMsg(control: some(ihaveMessage)), isHighPriority = false)
await waitForHeartbeat()

# Then node0 should receive an IWANT message from node1 (as node1 doesn't have the message)
# Wait until IWANT response is received
checkUntilTimeout:
receivedIWants[].len == 1

# Then node0 should receive exactly one IWANT message from node1
check:
receivedIWants[].len == 1
receivedIWants[0] == ControlIWant(messageIDs: @[messageID])

asyncTest "IDONTWANT":

@@ -334,14 +334,6 @@ suite "GossipSub Integration - Mesh Management":
# When DValues of Node0 are updated back to the initial dValues
node0.parameters.applyDValues(dValues)

# Waiting more than one heartbeat (60ms) and less than pruneBackoff (1s)
await sleepAsync(pruneBackoff.div(2))
check:
node0.mesh.getOrDefault(topic).len == newDValues.get.dHigh.get

# When pruneBackoff period is done
await sleepAsync(pruneBackoff)

# Then on the next heartbeat mesh is rebalanced and peers are regrafted to the initial d value
check:
checkUntilTimeout:
node0.mesh.getOrDefault(topic).len == dValues.get.d.get

@@ -76,33 +76,6 @@ suite "GossipSub Integration - Message Handling":
teardown:
checkTrackers()

asyncTest "Split IWANT replies when individual messages are below maxSize but combined exceed maxSize":
# This test checks if two messages, each below the maxSize, are correctly split when their combined size exceeds maxSize.
# Expected: Both messages should be received.
let (gossip0, gossip1, receivedMessages) = await setupTest()

let messageSize = gossip1.maxMessageSize div 2 + 1
let (iwantMessageIds, sentMessages) =
createMessages(gossip0, gossip1, messageSize, messageSize)

gossip1.broadcast(
gossip1.mesh["foobar"],
RPCMsg(
control: some(
ControlMessage(
ihave: @[ControlIHave(topicID: "foobar", messageIDs: iwantMessageIds)]
)
)
),
isHighPriority = false,
)

checkUntilTimeout:
receivedMessages[] == sentMessages
check receivedMessages[].len == 2

await teardownTest(gossip0, gossip1)

asyncTest "Discard IWANT replies when both messages individually exceed maxSize":
# This test checks if two messages, each exceeding the maxSize, are discarded and not sent.
# Expected: No messages should be received.
@@ -157,41 +130,6 @@ suite "GossipSub Integration - Message Handling":

await teardownTest(gossip0, gossip1)

asyncTest "Split IWANT replies when one message is below maxSize and the other exceeds maxSize":
# This test checks if, when given two messages where one is below maxSize and the other exceeds it, only the smaller message is processed and sent.
# Expected: Only the smaller message should be received.
let (gossip0, gossip1, receivedMessages) = await setupTest()
let maxSize = gossip1.maxMessageSize
let size1 = maxSize div 2
let size2 = maxSize + 10
let (bigIWantMessageIds, sentMessages) =
createMessages(gossip0, gossip1, size1, size2)

gossip1.broadcast(
gossip1.mesh["foobar"],
RPCMsg(
control: some(
ControlMessage(
ihave: @[ControlIHave(topicID: "foobar", messageIDs: bigIWantMessageIds)]
)
)
),
isHighPriority = false,
)

var smallestSet: HashSet[seq[byte]]
let seqs = toSeq(sentMessages)
if seqs[0] < seqs[1]:
smallestSet.incl(seqs[0])
else:
smallestSet.incl(seqs[1])

checkUntilTimeout:
receivedMessages[] == smallestSet
check receivedMessages[].len == 1

await teardownTest(gossip0, gossip1)

asyncTest "messages are not sent back to source or forwarding peer":
let
numberOfNodes = 3

@@ -27,6 +27,28 @@ suite "Message":

check verify(msg)

test "signature with missing key":
let
seqno = 11'u64
seckey = PrivateKey.random(Ed25519, rng[]).get()
pubkey = seckey.getPublicKey().get()
peer = PeerInfo.new(seckey)
check peer.peerId.hasPublicKey() == true
var msg = Message.init(some(peer), @[], "topic", some(seqno), sign = true)
msg.key = @[]
# get the key from fromPeer field (inlined)
check verify(msg)

test "signature without inlined pubkey in peerId":
let
seqno = 11'u64
peer = PeerInfo.new(PrivateKey.random(RSA, rng[]).get())
var msg = Message.init(some(peer), @[], "topic", some(seqno), sign = true)
msg.key = @[]
# shouldn't work since there's no key field
# and the key is not inlined in peerid (too large)
check verify(msg) == false

test "defaultMsgIdProvider success":
let
seqno = 11'u64

@@ -14,7 +14,7 @@
import chronos
import
../../libp2p/[protocols/connectivity/autonat/client, peerid, multiaddress, switch]
from ../../libp2p/protocols/connectivity/autonat/core import
from ../../libp2p/protocols/connectivity/autonat/types import
NetworkReachability, AutonatUnreachableError, AutonatError

type

160  tests/testautonatv2.nim  Normal file
@@ -0,0 +1,160 @@
{.used.}

# Nim-Libp2p
# Copyright (c) 2025 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.

import std/options
import chronos
import
../libp2p/[
switch,
transports/tcptransport,
upgrademngrs/upgrade,
builders,
protocols/connectivity/autonatv2/types,
protocols/connectivity/autonatv2/utils,
],
./helpers

proc checkEncodeDecode[T](msg: T) =
# this would be equivalent of doing the following (e.g. for DialBack)
# check msg == DialBack.decode(msg.encode()).get()
check msg == T.decode(msg.encode()).get()

proc newAutonatV2ServerSwitch(): Switch =
var builder = newStandardSwitchBuilder().withAutonatV2()
return builder.build()

suite "AutonatV2":
teardown:
checkTrackers()

asyncTest "encode/decode messages":
# DialRequest
checkEncodeDecode(
DialRequest(
addrs:
@[
MultiAddress.init("/ip4/127.0.0.1/tcp/4040").get(),
MultiAddress.init("/ip4/127.0.0.1/tcp/4041").get(),
],
nonce: 42,
)
)

# DialResponse
checkEncodeDecode(
DialResponse(
status: ResponseStatus.Ok,
addrIdx: Opt.some(1.uint32),
dialStatus: Opt.some(DialStatus.Ok),
)
)

# DialDataRequest
checkEncodeDecode(DialDataRequest(addrIdx: 42, numBytes: 128))

# DialDataResponse
checkEncodeDecode(DialDataResponse(data: @[1'u8, 2, 3, 4, 5]))

# AutonatV2Msg - DialRequest
checkEncodeDecode(
AutonatV2Msg(
msgType: MsgType.DialRequest,
dialReq: DialRequest(
addrs:
@[
MultiAddress.init("/ip4/127.0.0.1/tcp/4040").get(),
MultiAddress.init("/ip4/127.0.0.1/tcp/4041").get(),
],
nonce: 42,
),
)
)

# AutonatV2Msg - DialResponse
checkEncodeDecode(
AutonatV2Msg(
msgType: MsgType.DialResponse,
dialResp: DialResponse(
status: ResponseStatus.Ok,
addrIdx: Opt.some(1.uint32),
dialStatus: Opt.some(DialStatus.Ok),
),
)
)

# AutonatV2Msg - DialDataRequest
checkEncodeDecode(
AutonatV2Msg(
msgType: MsgType.DialDataRequest,
dialDataReq: DialDataRequest(addrIdx: 42, numBytes: 128),
)
)

# AutonatV2Msg - DialDataResponse
checkEncodeDecode(
AutonatV2Msg(
msgType: MsgType.DialDataResponse,
dialDataResp: DialDataResponse(data: @[1'u8, 2, 3, 4, 5]),
)
)

# DialBack
checkEncodeDecode(DialBack(nonce: 123456))

# DialBackResponse
checkEncodeDecode(DialBackResponse(status: DialBackStatus.Ok))

asyncTest "asNetworkReachability":
check asNetworkReachability(DialResponse(status: EInternalError)) == Unknown
check asNetworkReachability(DialResponse(status: ERequestRejected)) == Unknown
check asNetworkReachability(DialResponse(status: EDialRefused)) == Unknown
check asNetworkReachability(
DialResponse(status: ResponseStatus.Ok, dialStatus: Opt.none(DialStatus))
) == Unknown
check asNetworkReachability(
DialResponse(status: ResponseStatus.Ok, dialStatus: Opt.some(Unused))
) == Unknown
check asNetworkReachability(
DialResponse(status: ResponseStatus.Ok, dialStatus: Opt.some(EDialError))
) == NotReachable
check asNetworkReachability(
DialResponse(status: ResponseStatus.Ok, dialStatus: Opt.some(EDialBackError))
) == NotReachable
check asNetworkReachability(
DialResponse(status: ResponseStatus.Ok, dialStatus: Opt.some(DialStatus.Ok))
) == Reachable

asyncTest "asAutonatV2Response":
let addrs = @[MultiAddress.init("/ip4/127.0.0.1/tcp/4000").get()]
let errorDialResp = DialResponse(
status: ResponseStatus.Ok,
addrIdx: Opt.none(AddrIdx),
dialStatus: Opt.none(DialStatus),
)
check asAutonatV2Response(errorDialResp, addrs) ==
AutonatV2Response(
reachability: Unknown, dialResp: errorDialResp, addrs: Opt.none(MultiAddress)
)

let correctDialResp = DialResponse(
status: ResponseStatus.Ok,
addrIdx: Opt.some(0.AddrIdx),
dialStatus: Opt.some(DialStatus.Ok),
)
check asAutonatV2Response(correctDialResp, addrs) ==
AutonatV2Response(
reachability: Reachable, dialResp: correctDialResp, addrs: Opt.some(addrs[0])
)

asyncTest "Instanciate server":
let serverSwitch = newAutonatV2ServerSwitch()
await serverSwitch.start()
await serverSwitch.stop()
@@ -17,7 +17,7 @@ import ../libp2p/[connmanager, stream/connection, crypto/crypto, muxers/muxer, p
import helpers

proc getMuxer(peerId: PeerId, dir: Direction = Direction.In): Muxer =
return Muxer(connection: Connection.new(peerId, dir, Opt.none(MultiAddress)))
return Muxer(connection: Connection.new(peerId, dir))

type TestMuxer = ref object of Muxer
peerId: PeerId
@@ -25,7 +25,7 @@ type TestMuxer = ref object of Muxer
method newStream*(
m: TestMuxer, name: string = "", lazy: bool = false
): Future[Connection] {.async: (raises: [CancelledError, LPStreamError, MuxerError]).} =
Connection.new(m.peerId, Direction.Out, Opt.none(MultiAddress))
Connection.new(m.peerId, Direction.Out)

suite "Connection Manager":
teardown:
@@ -124,7 +124,7 @@ suite "Connection Manager":
let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet()

let muxer = new TestMuxer
let connection = Connection.new(peerId, Direction.In, Opt.none(MultiAddress))
let connection = Connection.new(peerId, Direction.In)
muxer.peerId = peerId
muxer.connection = connection

@@ -144,7 +144,7 @@ suite "Connection Manager":
let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet()

let muxer = new TestMuxer
let connection = Connection.new(peerId, Direction.In, Opt.none(MultiAddress))
let connection = Connection.new(peerId, Direction.In)
muxer.peerId = peerId
muxer.connection = connection

@@ -14,7 +14,7 @@ import unittest2

import ../libp2p/protocols/connectivity/dcutr/core as dcore
import ../libp2p/protocols/connectivity/dcutr/[client, server]
from ../libp2p/protocols/connectivity/autonat/core import NetworkReachability
from ../libp2p/protocols/connectivity/autonat/types import NetworkReachability
import ../libp2p/builders
import ../libp2p/utils/future
import ./helpers

55  tests/testipaddr.nim  Normal file
@@ -0,0 +1,55 @@
{.used.}

# Nim-Libp2p
# Copyright (c) 2025 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.

import std/options
import chronos
import ../libp2p/[utils/ipaddr], ./helpers

suite "IpAddr Utils":
teardown:
checkTrackers()

test "ipAddrMatches":
# same ip address
check ipAddrMatches(
MultiAddress.init("/ip4/127.0.0.1/tcp/4041").get(),
@[MultiAddress.init("/ip4/127.0.0.1/tcp/4040").get()],
)
# different ip address
check not ipAddrMatches(
MultiAddress.init("/ip4/127.0.0.2/tcp/4041").get(),
@[MultiAddress.init("/ip4/127.0.0.1/tcp/4040").get()],
)

test "ipSupport":
check ipSupport(@[MultiAddress.init("/ip4/127.0.0.1/tcp/4040").get()]) ==
(true, false)
check ipSupport(@[MultiAddress.init("/ip6/::1/tcp/4040").get()]) == (false, true)
check ipSupport(
@[
MultiAddress.init("/ip6/::1/tcp/4040").get(),
MultiAddress.init("/ip4/127.0.0.1/tcp/4040").get(),
]
) == (true, true)

test "isPrivate, isPublic":
check isPrivate("192.168.1.100")
check not isPublic("192.168.1.100")
check isPrivate("10.0.0.25")
check not isPublic("10.0.0.25")
check isPrivate("169.254.12.34")
check not isPublic("169.254.12.34")
check isPrivate("172.31.200.8")
check not isPublic("172.31.200.8")
check not isPrivate("1.1.1.1")
check isPublic("1.1.1.1")
check not isPrivate("185.199.108.153")
check isPublic("185.199.108.153")
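
A small sketch, not part of this diff, of how ipSupport as exercised above could drive an address-selection decision. Only the (v4, v6) tuple shape shown in the tests is assumed; the helper name and the policy strings are hypothetical.

import ../libp2p/[utils/ipaddr, multiaddress]

proc describeIpSupport*(addrs: seq[MultiAddress]): string =
  # ipSupport returns which IP families appear in the advertised addresses.
  let (hasV4, hasV6) = ipSupport(addrs)
  if hasV4 and hasV6:
    "dual-stack"
  elif hasV4:
    "ipv4-only"
  elif hasV6:
    "ipv6-only"
  else:
    "no ip transport"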
@@ -341,6 +341,20 @@ suite "MultiAddress test suite":
MultiAddress.init("/ip4/0.0.0.0").get().protoAddress().get() == address_v4
MultiAddress.init("/ip6/::0").get().protoAddress().get() == address_v6

test "MultiAddress getPart":
let ma = MultiAddress
.init(
"/ip4/0.0.0.0/tcp/0/p2p/QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSupNKC/p2p-circuit/p2p/QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSuNEXT/unix/stdio/"
)
.get()
check:
$ma.getPart(multiCodec("ip4")).get() == "/ip4/0.0.0.0"
$ma.getPart(multiCodec("tcp")).get() == "/tcp/0"
# returns first codec match
$ma.getPart(multiCodec("p2p")).get() ==
"/p2p/QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSupNKC"
ma.getPart(multiCodec("udp")).isErr()

test "MultiAddress getParts":
let ma = MultiAddress
.init(
@@ -421,3 +435,22 @@ suite "MultiAddress test suite":
for item in CrashesVectors:
let res = MultiAddress.init(hexToSeqByte(item))
check res.isErr()

test "areAddrsConsistent":
# same address should be consistent
check areAddrsConsistent(
MultiAddress.init("/ip4/127.0.0.1/tcp/4040").get(),
MultiAddress.init("/ip4/127.0.0.1/tcp/4040").get(),
)

# different addresses with same stack should be consistent
check areAddrsConsistent(
MultiAddress.init("/ip4/127.0.0.2/tcp/4041").get(),
MultiAddress.init("/ip4/127.0.0.1/tcp/4040").get(),
)

# addresses with different stacks should not be consistent
check not areAddrsConsistent(
MultiAddress.init("/ip4/127.0.0.1/tcp/4040").get(),
MultiAddress.init("/ip4/127.0.0.1/udp/4040").get(),
)

@@ -30,11 +30,14 @@ import
import
testnameresolve, testmultistream, testbufferstream, testidentify,
testobservedaddrmanager, testconnmngr, testswitch, testnoise, testpeerinfo,
testpeerstore, testping, testmplex, testrelayv1, testrelayv2, testrendezvous,
testdiscovery, testyamux, testautonat, testautonatservice, testautorelay, testdcutr,
testhpservice, testutility, testhelpers, testwildcardresolverservice, testperf
testpeerstore, testping, testmplex, testrelayv1, testrelayv2, testyamux,
testyamuxheader, testautonat, testautonatservice, testautonatv2, testautorelay,
testdcutr, testhpservice, testutility, testhelpers, testwildcardresolverservice,
testperf

import kademlia/[testencoding, testroutingtable, testfindnode]
import discovery/testdiscovery

import kademlia/[testencoding, testroutingtable, testfindnode, testputval]

when defined(libp2p_autotls_support):
import testautotls

@@ -15,27 +15,36 @@ import
import ./helpers

proc createServerAcceptConn(
server: QuicTransport
server: QuicTransport, isEofExpected: bool = false
): proc(): Future[void] {.
async: (raises: [transport.TransportError, LPStreamError, CancelledError])
.} =
proc handler() {.
async: (raises: [transport.TransportError, LPStreamError, CancelledError])
.} =
try:
let conn = await server.accept()
while true:
let conn =
try:
await server.accept()
except QuicTransportAcceptStopped:
return # Transport is stopped
if conn == nil:
return
continue

let stream = await getStream(QuicSession(conn), Direction.In)
var resp: array[6, byte]
await stream.readExactly(addr resp, 6)
check string.fromBytes(resp) == "client"
defer:
await stream.close()

await stream.write("server")
await stream.close()
except QuicTransportAcceptStopped:
discard # Transport is stopped
try:
var resp: array[6, byte]
await stream.readExactly(addr resp, 6)
check string.fromBytes(resp) == "client"
await stream.write("server")
except LPStreamEOFError as exc:
if isEofExpected:
discard
else:
raise exc

return handler

@@ -119,9 +128,27 @@ suite "Quic transport":

await runClient()

asyncTest "server not accepting":
let server = await createTransport()
# itentionally not calling createServerAcceptConn as server should not accept
defer:
await server.stop()

proc runClient() {.async.} =
# client should be able to write even when server has not accepted
let client = await createTransport()
let conn = await client.dial("", server.addrs[0])
let stream = await getStream(QuicSession(conn), Direction.Out)
await stream.write("client")
await client.stop()

await runClient()

asyncTest "closing session should close all streams":
let server = await createTransport()
asyncSpawn createServerAcceptConn(server)()
# because some clients will not write full message,
# it is expected for server to receive eof
asyncSpawn createServerAcceptConn(server, true)()
defer:
await server.stop()

@@ -147,3 +174,36 @@ suite "Quic transport":

# run multiple clients simultainiously
await allFutures(runClient(), runClient(), runClient())

asyncTest "read/write Lp":
proc serverHandler(
server: QuicTransport
) {.async: (raises: [transport.TransportError, LPStreamError, CancelledError]).} =
while true:
let conn =
try:
await server.accept()
except QuicTransportAcceptStopped:
return # Transport is stopped
if conn == nil:
continue

let stream = await getStream(QuicSession(conn), Direction.In)
check (await stream.readLp(100)) == fromHex("1234")
await stream.writeLp(fromHex("5678"))
await stream.close()

proc runClient(server: QuicTransport) {.async.} =
let client = await createTransport()
let conn = await client.dial("", server.addrs[0])
let stream = await getStream(QuicSession(conn), Direction.Out)
await stream.writeLp(fromHex("1234"))
check (await stream.readLp(100)) == fromHex("5678")
await client.stop()

let server = await createTransport()
asyncSpawn serverHandler(server)
defer:
await server.stop()

await runClient(server)

@@ -1,158 +0,0 @@
{.used.}

# Nim-Libp2p
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.

import sequtils, strutils
import chronos
import ../libp2p/[protocols/rendezvous, switch, builders]
import ../libp2p/discovery/discoverymngr
import ./helpers

proc createSwitch(rdv: RendezVous = RendezVous.new()): Switch =
SwitchBuilder
.new()
.withRng(newRng())
.withAddresses(@[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()])
.withTcpTransport()
.withMplex()
.withNoise()
.withRendezVous(rdv)
.build()

suite "RendezVous":
teardown:
checkTrackers()
asyncTest "Simple local test":
let
rdv = RendezVous.new()
s = createSwitch(rdv)

await s.start()
let res0 = rdv.requestLocally("empty")
check res0.len == 0
await rdv.advertise("foo")
let res1 = rdv.requestLocally("foo")
check:
res1.len == 1
res1[0] == s.peerInfo.signedPeerRecord.data
let res2 = rdv.requestLocally("bar")
check res2.len == 0
rdv.unsubscribeLocally("foo")
let res3 = rdv.requestLocally("foo")
check res3.len == 0
await s.stop()

asyncTest "Simple remote test":
let
rdv = RendezVous.new()
client = createSwitch(rdv)
remoteSwitch = createSwitch()

await client.start()
await remoteSwitch.start()
await client.connect(remoteSwitch.peerInfo.peerId, remoteSwitch.peerInfo.addrs)
let res0 = await rdv.request(Opt.some("empty"))
check res0.len == 0

await rdv.advertise("foo")
let res1 = await rdv.request(Opt.some("foo"))
check:
res1.len == 1
res1[0] == client.peerInfo.signedPeerRecord.data

let res2 = await rdv.request(Opt.some("bar"))
check res2.len == 0

await rdv.unsubscribe("foo")
let res3 = await rdv.request(Opt.some("foo"))
check res3.len == 0

await allFutures(client.stop(), remoteSwitch.stop())

asyncTest "Harder remote test":
var
rdvSeq: seq[RendezVous] = @[]
clientSeq: seq[Switch] = @[]
remoteSwitch = createSwitch()

for x in 0 .. 10:
rdvSeq.add(RendezVous.new())
clientSeq.add(createSwitch(rdvSeq[^1]))
await remoteSwitch.start()
await allFutures(clientSeq.mapIt(it.start()))
await allFutures(
clientSeq.mapIt(remoteSwitch.connect(it.peerInfo.peerId, it.peerInfo.addrs))
)
await allFutures(rdvSeq.mapIt(it.advertise("foo")))
var data = clientSeq.mapIt(it.peerInfo.signedPeerRecord.data)
let res1 = await rdvSeq[0].request(Opt.some("foo"), 5)
check res1.len == 5
for d in res1:
check d in data
data.keepItIf(it notin res1)
let res2 = await rdvSeq[0].request(Opt.some("foo"))
check res2.len == 5
for d in res2:
check d in data
let res3 = await rdvSeq[0].request(Opt.some("foo"))
check res3.len == 0
let res4 = await rdvSeq[0].request()
check res4.len == 11
let res5 = await rdvSeq[0].request(Opt.none(string))
check res5.len == 11
await remoteSwitch.stop()
await allFutures(clientSeq.mapIt(it.stop()))

asyncTest "Simple cookie test":
let
rdvA = RendezVous.new()
rdvB = RendezVous.new()
clientA = createSwitch(rdvA)
clientB = createSwitch(rdvB)
remoteSwitch = createSwitch()

await clientA.start()
await clientB.start()
await remoteSwitch.start()
await clientA.connect(remoteSwitch.peerInfo.peerId, remoteSwitch.peerInfo.addrs)
await clientB.connect(remoteSwitch.peerInfo.peerId, remoteSwitch.peerInfo.addrs)
await rdvA.advertise("foo")
let res1 = await rdvA.request(Opt.some("foo"))
await rdvB.advertise("foo")
let res2 = await rdvA.request(Opt.some("foo"))
check:
res2.len == 1
res2[0] == clientB.peerInfo.signedPeerRecord.data
await allFutures(clientA.stop(), clientB.stop(), remoteSwitch.stop())

asyncTest "Various local error":
let
rdv = RendezVous.new(minDuration = 1.minutes, maxDuration = 72.hours)
switch = createSwitch(rdv)
expect AdvertiseError:
discard await rdv.request(Opt.some("A".repeat(300)))
expect AdvertiseError:
discard await rdv.request(Opt.some("A"), -1)
expect AdvertiseError:
discard await rdv.request(Opt.some("A"), 3000)
expect AdvertiseError:
await rdv.advertise("A".repeat(300))
expect AdvertiseError:
await rdv.advertise("A", 73.hours)
expect AdvertiseError:
await rdv.advertise("A", 30.seconds)

test "Various config error":
expect RendezVousError:
discard RendezVous.new(minDuration = 30.seconds)
expect RendezVousError:
discard RendezVous.new(maxDuration = 73.hours)
expect RendezVousError:
discard RendezVous.new(minDuration = 15.minutes, maxDuration = 10.minutes)
@@ -11,7 +11,11 @@

import sugar
import chronos
import ../libp2p/[stream/connection, stream/bridgestream, muxers/yamux/yamux], ./helpers
import ../libp2p/[stream/connection, stream/bridgestream, muxers/yamux/yamux]
import ./helpers
import ./utils/futures

include ../libp2p/muxers/yamux/yamux

proc newBlockerFut(): Future[void] {.async: (raises: [], raw: true).} =
newFuture[void]()
@@ -24,6 +28,8 @@ suite "Yamux":
ws: int = YamuxDefaultWindowSize,
inTo: Duration = 5.minutes,
outTo: Duration = 5.minutes,
startHandlera = true,
startHandlerb = true,
) {.inject.} =
#TODO in a template to avoid threadvar
let
@@ -34,7 +40,14 @@ suite "Yamux":
Yamux.new(conna, windowSize = ws, inTimeout = inTo, outTimeout = outTo)
yamuxb {.inject.} =
Yamux.new(connb, windowSize = ws, inTimeout = inTo, outTimeout = outTo)
(handlera, handlerb) = (yamuxa.handle(), yamuxb.handle())
var
handlera = completedFuture()
handlerb = completedFuture()

if startHandlera:
handlera = yamuxa.handle()
if startHandlerb:
handlerb = yamuxb.handle()

defer:
await allFutures(
@@ -166,8 +179,9 @@ suite "Yamux":

let writerBlocker = newBlockerFut()
var numberOfRead = 0
const newWindow = 20
yamuxb.streamHandler = proc(conn: Connection) {.async: (raises: []).} =
YamuxChannel(conn).setMaxRecvWindow(20)
YamuxChannel(conn).setMaxRecvWindow(newWindow)
try:
var buffer: array[256000, byte]
while (await conn.readOnce(addr buffer[0], 256000)) > 0:
@@ -183,13 +197,14 @@ suite "Yamux":

# Need to exhaust initial window first
await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block
await streamA.write(newSeq[byte](142))
const extraBytes = 160
await streamA.write(newSeq[byte](extraBytes))
await streamA.close()

await writerBlocker

# 1 for initial exhaustion + (142 / 20) = 9
check numberOfRead == 9
# 1 for initial exhaustion + (160 / 20) = 9
check numberOfRead == 1 + (extraBytes / newWindow).int

asyncTest "Saturate until reset":
mSetup()
@@ -412,3 +427,101 @@ suite "Yamux":
await streamA.writeLp(fromHex("1234"))
await streamA.close()
check (await streamA.readLp(100)) == fromHex("5678")

suite "Frame handling and stream initiation":
asyncTest "Ping Syn responds Ping Ack":
mSetup(startHandlera = false)

let payload: uint32 = 0x12345678'u32
await conna.write(YamuxHeader.ping(MsgFlags.Syn, payload))

let header = await conna.readHeader()
check:
header.msgType == Ping
header.flags == {Ack}
header.length == payload

asyncTest "Go Away Status responds with Go Away":
mSetup(startHandlera = false)

await conna.write(YamuxHeader.goAway(GoAwayStatus.ProtocolError))

let header = await conna.readHeader()
check:
header.msgType == GoAway
header.flags == {}
header.length == GoAwayStatus.NormalTermination.uint32

for testCase in [
YamuxHeader.data(streamId = 1'u32, length = 0, {Syn}),
YamuxHeader.windowUpdate(streamId = 5'u32, delta = 0, {Syn}),
]:
asyncTest "Syn opens stream and sends Ack - " & $testCase:
mSetup(startHandlera = false)

yamuxb.streamHandler = proc(conn: Connection) {.async: (raises: []).} =
try:
await conn.close()
except CancelledError, LPStreamError:
return

await conna.write(testCase)

let ackHeader = await conna.readHeader()
check:
ackHeader.msgType == WindowUpdate
ackHeader.streamId == testCase.streamId
ackHeader.flags == {Ack}

check:
yamuxb.channels.hasKey(testCase.streamId)
yamuxb.channels[testCase.streamId].opened

let finHeader = await conna.readHeader()
check:
finHeader.msgType == Data
finHeader.streamId == testCase.streamId
finHeader.flags == {Fin}

for badHeader in [
# Reserved parity on Data+Syn (even id against responder)
YamuxHeader.data(streamId = 2'u32, length = 0, {Syn}),
# Reserved stream id 0
YamuxHeader.data(streamId = 0'u32, length = 0, {Syn}),
# Reserved parity on WindowUpdate+Syn (even id against responder)
YamuxHeader.windowUpdate(streamId = 4'u32, delta = 0, {Syn}),
]:
asyncTest "Reject invalid/unknown header - " & $badHeader:
mSetup(startHandlera = false)

await conna.write(badHeader)

let header = await conna.readHeader()
check:
header.msgType == GoAway
header.flags == {}
header.length == GoAwayStatus.ProtocolError.uint32
not yamuxb.channels.hasKey(badHeader.streamId)

asyncTest "Flush unknown-stream Data up to budget then ProtocolError when exceeded":
# Cover the flush path: streamId not in channels, no Syn, with a pre-seeded
# flush budget in yamuxb.flushed. First frame should be flushed (no GoAway),
# second frame exceeding the remaining budget should trigger ProtocolError.
mSetup(startHandlera = false)

let streamId = 11'u32
yamuxb.flushed[streamId] = 4 # allow up to 4 bytes to be flushed

# 1) Send a Data frame (no Syn) with length=3 and a 3-byte payload -> should be flushed.
await conna.write(YamuxHeader.data(streamId = streamId, length = 3))
await conna.write(fromHex("010203"))

# 2) Send another Data frame with length=2 (remaining budget is 1) -> exceeds, expect GoAway.
await conna.write(YamuxHeader.data(streamId = streamId, length = 2))
await conna.write(fromHex("0405"))

let header = await conna.readHeader()
check:
header.msgType == GoAway
header.flags == {}
header.length == GoAwayStatus.ProtocolError.uint32

321  tests/testyamuxheader.nim  Normal file
@@ -0,0 +1,321 @@
{.used.}

# Nim-Libp2p
# Copyright (c) 2025 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.

import ../libp2p/stream/[bufferstream, lpstream]
import ./helpers

include ../libp2p/muxers/yamux/yamux

proc readBytes(bytes: array[12, byte]): Future[YamuxHeader] {.async.} =
let bs = BufferStream.new()
defer:
await bs.close()

await bs.pushData(@bytes)

return await readHeader(bs)

suite "Yamux Header Tests":
teardown:
checkTrackers()

asyncTest "Data header":
const
streamId = 1
length = 100
flags = {Syn}
let header = YamuxHeader.data(streamId = streamId, length = length, flags)
let dataEncoded = header.encode()

# [version == 0, msgType, flags_high, flags_low, 4x streamId_bytes, 4x length_bytes]
const expected = [byte 0, 0, 0, 1, 0, 0, 0, streamId.byte, 0, 0, 0, length.byte]
check:
dataEncoded == expected

let headerDecoded = await readBytes(dataEncoded)
check:
headerDecoded.version == 0
headerDecoded.msgType == MsgType.Data
headerDecoded.flags == flags
headerDecoded.streamId == streamId
headerDecoded.length == length

asyncTest "Window update":
const
streamId = 5
delta = 1000
flags = {Syn}
let windowUpdateHeader =
YamuxHeader.windowUpdate(streamId = streamId, delta = delta, flags)
let windowEncoded = windowUpdateHeader.encode()

# [version == 0, msgType, flags_high, flags_low, 4x streamId_bytes, 4x delta_bytes]
# delta == 1000 == 0x03E8
const expected = [byte 0, 1, 0, 1, 0, 0, 0, streamId.byte, 0, 0, 0x03, 0xE8]
check:
windowEncoded == expected

let windowDecoded = await readBytes(windowEncoded)
check:
windowDecoded.version == 0
windowDecoded.msgType == MsgType.WindowUpdate
windowDecoded.flags == flags
windowDecoded.streamId == streamId
windowDecoded.length == delta

asyncTest "Ping":
let pingHeader = YamuxHeader.ping(MsgFlags.Syn, 0x12345678'u32)
let pingEncoded = pingHeader.encode()

# [version == 0, msgType, flags_high, flags_low, 4x streamId_bytes, 4x value_bytes]
const expected = [byte 0, 2, 0, 1, 0, 0, 0, 0, 0x12, 0x34, 0x56, 0x78]
check:
pingEncoded == expected

let pingDecoded = await readBytes(pingEncoded)
check:
pingDecoded.version == 0
pingDecoded.msgType == MsgType.Ping
pingDecoded.flags == {Syn}
pingDecoded.streamId == 0
pingDecoded.length == 0x12345678'u32

asyncTest "Go away":
let goAwayHeader = YamuxHeader.goAway(GoAwayStatus.ProtocolError)
let goAwayEncoded = goAwayHeader.encode()

# [version == 0, msgType, flags_high, flags_low, 4x streamId_bytes, 4x error_bytes]
const expected = [byte 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
check:
goAwayEncoded == expected

let goAwayDecoded = await readBytes(goAwayEncoded)
check:
goAwayDecoded.version == 0
goAwayDecoded.msgType == MsgType.GoAway
goAwayDecoded.flags == {}
goAwayDecoded.streamId == 0
goAwayDecoded.length == 1'u32

asyncTest "Error codes":
let encodedNormal = YamuxHeader.goAway(GoAwayStatus.NormalTermination).encode()
let encodedProtocol = YamuxHeader.goAway(GoAwayStatus.ProtocolError).encode()
let encodedInternal = YamuxHeader.goAway(GoAwayStatus.InternalError).encode()
check:
encodedNormal[11] == 0
encodedProtocol[11] == 1
encodedInternal[11] == 2

let decodedNormal = await readBytes(encodedNormal)
let decodedProtocol = await readBytes(encodedProtocol)
let decodedInternal = await readBytes(encodedInternal)
check:
decodedNormal.msgType == MsgType.GoAway
decodedNormal.length == 0'u32
decodedProtocol.msgType == MsgType.GoAway
decodedProtocol.length == 1'u32
decodedInternal.msgType == MsgType.GoAway
decodedInternal.length == 2'u32

asyncTest "Flags":
|
||||
const
|
||||
streamId = 1
|
||||
length = 100
|
||||
let cases: seq[(set[MsgFlags], uint8)] =
|
||||
@[
|
||||
({}, 0'u8),
|
||||
({Syn}, 1'u8),
|
||||
({Ack}, 2'u8),
|
||||
({Syn, Ack}, 3'u8),
|
||||
({Fin}, 4'u8),
|
||||
({Syn, Fin}, 5'u8),
|
||||
({Ack, Fin}, 6'u8),
|
||||
({Syn, Ack, Fin}, 7'u8),
|
||||
({Rst}, 8'u8),
|
||||
({Syn, Rst}, 9'u8),
|
||||
({Ack, Rst}, 10'u8),
|
||||
({Syn, Ack, Rst}, 11'u8),
|
||||
({Fin, Rst}, 12'u8),
|
||||
({Syn, Fin, Rst}, 13'u8),
|
||||
({Ack, Fin, Rst}, 14'u8),
|
||||
({Syn, Ack, Fin, Rst}, 15'u8),
|
||||
]
|
||||
|
||||
for (flags, low) in cases:
|
||||
let header = YamuxHeader.data(streamId = streamId, length = length, flags)
|
||||
let encoded = header.encode()
|
||||
check encoded[2 .. 3] == [byte 0, low]
|
||||
|
||||
let decoded = await readBytes(encoded)
|
||||
check decoded.flags == flags
|
||||
|
||||
asyncTest "Boundary conditions":
|
||||
# Test maximum values
|
||||
const maxFlags = {Syn, Ack, Fin, Rst}
|
||||
let maxHeader =
|
||||
YamuxHeader.data(streamId = uint32.high, length = uint32.high, maxFlags)
|
||||
let maxEncoded = maxHeader.encode()
|
||||
|
||||
const maxExpected = [byte 0, 0, 0, 15, 255, 255, 255, 255, 255, 255, 255, 255]
|
||||
check:
|
||||
maxEncoded == maxExpected
|
||||
|
||||
let maxDecoded = await readBytes(maxEncoded)
|
||||
check:
|
||||
maxDecoded.version == 0
|
||||
maxDecoded.msgType == MsgType.Data
|
||||
maxDecoded.flags == maxFlags
|
||||
maxDecoded.streamId == uint32.high
|
||||
maxDecoded.length == uint32.high
|
||||
|
||||
# Test minimum values
|
||||
let minHeader = YamuxHeader.data(streamId = 0, length = 0, {})
|
||||
let minEncoded = minHeader.encode()
|
||||
|
||||
const minExpected = [byte 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
|
||||
check:
|
||||
minEncoded == minExpected
|
||||
|
||||
let minDecoded = await readBytes(minEncoded)
|
||||
check:
|
||||
minDecoded.version == 0
|
||||
minDecoded.msgType == MsgType.Data
|
||||
minDecoded.flags == {}
|
||||
minDecoded.streamId == 0
|
||||
minDecoded.length == 0'u32
|
||||
|
||||
asyncTest "Incomplete header should raise LPStreamIncompleteError":
|
||||
let buff = BufferStream.new()
|
||||
|
||||
let valid = YamuxHeader.data(streamId = 7, length = 0, {}).encode()
|
||||
# Supply only first 10 bytes (<12)
|
||||
let partial: seq[byte] = @valid[0 .. 9]
|
||||
await buff.pushData(partial)
|
||||
|
||||
# Start the read first so close() can propagate EOF to it
|
||||
let headerFut = readHeader(buff)
|
||||
await buff.close()
|
||||
|
||||
expect LPStreamIncompleteError:
|
||||
discard await headerFut
|
||||
|
||||
asyncTest "Non-zero version byte is preserved":
|
||||
let valid = YamuxHeader.data(streamId = 1, length = 100, {Syn}).encode()
|
||||
var mutated = valid
|
||||
mutated[0] = 1'u8
|
||||
|
||||
let decoded = await readBytes(mutated)
|
||||
check:
|
||||
decoded.version == 1
|
||||
|
||||
asyncTest "Invalid msgType should raise YamuxError":
|
||||
let valid = YamuxHeader.data(streamId = 1, length = 0, {}).encode()
|
||||
var mutated = valid
|
||||
mutated[1] = 0xFF'u8
|
||||
|
||||
expect YamuxError:
|
||||
discard await readBytes(mutated)
|
||||
|
||||
asyncTest "Invalid flags should raise YamuxError":
|
||||
let valid = YamuxHeader.data(streamId = 1, length = 0, {}).encode()
|
||||
var mutated = valid
|
||||
# Set flags to 16 which is outside the allowed 0..15 range
|
||||
mutated[2] = 0'u8
|
||||
mutated[3] = 16'u8
|
||||
|
||||
expect YamuxError:
|
||||
discard await readBytes(mutated)
|
||||
|
||||
asyncTest "Invalid flags (high byte non-zero) should raise YamuxError":
|
||||
let valid = YamuxHeader.data(streamId = 1, length = 0, {}).encode()
|
||||
var mutated = valid
|
||||
# Set high flags byte to 1, which is outside the allowed 0..15 range
|
||||
mutated[2] = 1'u8
|
||||
mutated[3] = 0'u8
|
||||
|
||||
expect YamuxError:
|
||||
discard await readBytes(mutated)
|
||||
|
||||
asyncTest "Partial push (6+6 bytes) completes without closing":
|
||||
const
|
||||
streamId = 9
|
||||
length = 42
|
||||
flags = {Syn}
|
||||
|
||||
# Prepare a valid header
|
||||
let header = YamuxHeader.data(streamId = streamId, length = length, flags)
|
||||
let bytes = header.encode()
|
||||
|
||||
let buff = BufferStream.new()
|
||||
defer:
|
||||
await buff.close()
|
||||
|
||||
# Push first half (6 bytes)
|
||||
let first: seq[byte] = @bytes[0 .. 5]
|
||||
await buff.pushData(first)
|
||||
|
||||
# Start read and then push the remaining bytes
|
||||
let headerFut = readHeader(buff)
|
||||
let second: seq[byte] = @bytes[6 .. 11]
|
||||
await buff.pushData(second)
|
||||
|
||||
let decoded = await headerFut
|
||||
check:
|
||||
decoded.version == 0
|
||||
decoded.msgType == MsgType.Data
|
||||
decoded.flags == flags
|
||||
decoded.streamId == streamId
|
||||
decoded.length == length
|
||||
|
||||
asyncTest "Two headers back-to-back decode sequentially":
|
||||
let h1 = YamuxHeader.data(streamId = 2, length = 10, {Ack})
|
||||
let h2 = YamuxHeader.ping(MsgFlags.Syn, 0xABCDEF01'u32)
|
||||
let b1 = h1.encode()
|
||||
let b2 = h2.encode()
|
||||
|
||||
let buff = BufferStream.new()
|
||||
defer:
|
||||
await buff.close()
|
||||
|
||||
await buff.pushData(@b1 & @b2)
|
||||
|
||||
let d1 = await readHeader(buff)
|
||||
let d2 = await readHeader(buff)
|
||||
|
||||
check:
|
||||
d1.msgType == MsgType.Data
|
||||
d1.streamId == 2
|
||||
d1.length == 10
|
||||
d1.flags == {Ack}
|
||||
d2.msgType == MsgType.Ping
|
||||
d2.streamId == 0
|
||||
d2.length == 0xABCDEF01'u32
|
||||
d2.flags == {Syn}
|
||||
|
||||
asyncTest "StreamId 0x01020304 encodes big-endian":
|
||||
const streamId = 0x01020304'u32
|
||||
let header = YamuxHeader.data(streamId = streamId, length = 0, {})
|
||||
let enc = header.encode()
|
||||
check enc[4 .. 7] == [byte 1, 2, 3, 4]
|
||||
|
||||
let dec = await readBytes(enc)
|
||||
check dec.streamId == streamId
|
||||
|
||||
asyncTest "GoAway unknown status code is preserved":
|
||||
let valid = YamuxHeader.goAway(GoAwayStatus.NormalTermination).encode()
|
||||
var mutated = valid
|
||||
# Set the GoAway code (last byte) to 255, which is not a known GoAwayStatus
|
||||
mutated[11] = 255'u8
|
||||
|
||||
let decoded = await readBytes(mutated)
|
||||
check:
|
||||
decoded.msgType == MsgType.GoAway
|
||||
decoded.length == 255'u32
|
||||
@@ -59,3 +59,8 @@ proc waitForStates*[T](
|
||||
): Future[seq[FutureStateWrapper[T]]] {.async.} =
|
||||
await sleepAsync(timeout)
|
||||
return futures.mapIt(it.toState())
|
||||
|
||||
proc completedFuture*(): Future[void] =
|
||||
let f = newFuture[void]()
|
||||
f.complete()
|
||||
f
|
||||
|
||||