Mirror of https://github.com/vacp2p/nim-libp2p.git (synced 2026-01-10 12:58:05 -05:00)

Compare commits: depr-updat ... async-fix

87 Commits
Commits in this compare (SHA1):
205a17d8a0, 7d9a2cef69, 4fbf59ece8, 62388a7a20, 27051164db, f41009461b, c3faabf522, 10f7f5c68a, f345026900, 5d6578a06f,
871a5d047f, 061195195b, 8add5aaaab, dbf60b74c7, d2eaf07960, 6e5274487e, 7ed62461d7, 6059ee8332, 4f7e232a9e, 5eaa43b860,
17ed2d88df, c7f29ed5db, 9865cc39b5, 601f56b786, 25a8ed4d07, 955e28ff70, f952e6d436, bed83880bf, 9bd4b7393f, 12d1fae404,
17073dc9e0, b1649b3566, ef20f46b47, 9161529c84, 8b70384b6a, f25814a890, 3d5ea1fa3c, 2114008704, 04796b210b, 59faa023aa,
fdebea4e14, 0c188df806, abee5326dc, 71f04d1bb3, 41ae43ae80, 5dbf077d9e, b5fc7582ff, 7f83ebb198, ceb89986c1, f4ff27ca6b,
b517b692df, 7cfd26035a, cd5fea53e3, d9aa393761, a4a0d9e375, c8b406d6ed, f0125a62df, 9bf2636186, 01a33ebe5c, c1cd31079b,
9f9f38e314, f83638eb82, 882cb5dfe3, 81310df2a2, 34110a37d7, 1035e4f314, d08bad5893, 7bdba4909f, e71c7caf82, 45476bdd6b,
c7ee7b950d, 87b3d2c864, 19b4c20e2f, 514bd4b5f5, 46d936b80c, 80bf27c6bb, 6576c5c3bf, 2e6b1d2738, 9e6c4cb4d2, 5f256049ab,
e29ca73386, 577809750a, 46a5430cc2, d8b9f59c5e, 2951356c9d, 7ae21d0cbd, eee8341ad2
34  .github/actions/add_comment/action.yml  (vendored, new file)
@@ -0,0 +1,34 @@
name: Add Comment
description: "Add or update comment in the PR"

runs:
  using: "composite"
  steps:
    - name: Add/Update Comment
      uses: actions/github-script@v7
      with:
        script: |
          const fs = require('fs');
          const marker = "${{ env.MARKER }}";
          const body = fs.readFileSync("${{ env.COMMENT_SUMMARY_PATH }}", 'utf8');
          const { data: comments } = await github.rest.issues.listComments({
            owner: context.repo.owner,
            repo: context.repo.repo,
            issue_number: context.issue.number,
          });
          const existing = comments.find(c => c.body && c.body.startsWith(marker));
          if (existing) {
            await github.rest.issues.updateComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              comment_id: existing.id,
              body,
            });
          } else {
            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: context.issue.number,
              body,
            });
          }
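For reference, a minimal sketch of how a job can invoke this composite action: the action reads `MARKER` and `COMMENT_SUMMARY_PATH` from the job environment rather than from inputs. The values below mirror the Performance workflow later in this compare; the job name is illustrative.

```yaml
# Hypothetical calling job; env values mirror .github/workflows/performance.yml in this compare.
jobs:
  comment:
    runs-on: ubuntu-latest
    env:
      MARKER: "<!-- perf-summary-marker -->"
      COMMENT_SUMMARY_PATH: "/tmp/perf-summary.md"
    steps:
      - uses: actions/checkout@v4          # checkout is needed so the local action path exists
      - name: Post/Update PR comment
        if: github.event_name == 'pull_request'
        uses: ./.github/actions/add_comment
```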
49  .github/actions/discord_notify/action.yml  (vendored, new file)
@@ -0,0 +1,49 @@
name: Discord Failure Notification
description: "Send Discord notification when CI jobs fail"
inputs:
  webhook_url:
    description: "Discord webhook URL"
    required: true
  workflow_name:
    description: "Name of the workflow that failed"
    required: false
    default: ${{ github.workflow }}
  branch:
    description: "Branch name"
    required: false
    default: ${{ github.ref_name }}
  repository:
    description: "Repository name"
    required: false
    default: ${{ github.repository }}
  run_id:
    description: "GitHub run ID"
    required: false
    default: ${{ github.run_id }}
  server_url:
    description: "GitHub server URL"
    required: false
    default: ${{ github.server_url }}

runs:
  using: "composite"
  steps:
    - name: Send Discord notification
      shell: bash
      run: |
        curl -H "Content-Type: application/json" \
          -X POST \
          -d "{
            \"embeds\": [{
              \"title\": \"${{ inputs.workflow_name }} Job Failed\",
              \"url\": \"${{ inputs.server_url }}/${{ inputs.repository }}/actions/runs/${{ inputs.run_id }}\",
              \"description\": \"The workflow has failed on branch \`${{ inputs.branch }}\`\",
              \"color\": 15158332,
              \"fields\": [
                {\"name\": \"Repository\", \"value\": \"${{ inputs.repository }}\", \"inline\": true},
                {\"name\": \"Branch\", \"value\": \"${{ inputs.branch }}\", \"inline\": true}
              ],
              \"timestamp\": \"$(date -u +%Y-%m-%dT%H:%M:%S.000Z)\"
            }]
          }" \
          "${{ inputs.webhook_url }}"
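Only `webhook_url` must be supplied; the other inputs default to the current workflow context. A minimal sketch of the failure-notification pattern used by the daily workflows further down in this compare (the upstream job name here is hypothetical):

```yaml
# Sketch: notify Discord only when an upstream job fails.
notify-on-failure:
  name: Notify Discord on Failure
  needs: [some_test_job]            # hypothetical upstream job
  if: failure()
  runs-on: ubuntu-latest
  steps:
    - uses: actions/checkout@v4     # required so ./.github/actions/discord_notify is available
    - uses: ./.github/actions/discord_notify
      with:
        webhook_url: ${{ secrets.DISCORD_WEBHOOK_URL }}
```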
24  .github/actions/generate_plots/action.yml  (vendored, new file)
@@ -0,0 +1,24 @@
name: Generate Plots
description: "Set up Python and run script to generate plots with Docker Stats"

runs:
  using: "composite"
  steps:
    - name: Set up Python
      uses: actions/setup-python@v5
      with:
        python-version: "3.12"

    - name: Install Python dependencies
      shell: bash
      run: |
        python -m pip install --upgrade pip
        pip install matplotlib

    - name: Plot Docker Stats
      shell: bash
      run: python performance/scripts/plot_docker_stats.py

    - name: Plot Latency History
      shell: bash
      run: python performance/scripts/plot_latency_history.py
2  .github/actions/install_nim/action.yml  (vendored)
@@ -8,7 +8,7 @@ inputs:
    default: "amd64"
  nim_ref:
    description: "Nim version"
-   default: "version-1-6"
+   default: "version-2-0"
  shell:
    description: "Shell to run commands in"
    default: "bash --noprofile --norc -e -o pipefail"
21  .github/actions/process_stats/action.yml  (vendored, new file)
@@ -0,0 +1,21 @@
name: Process Stats
description: "Set up Nim and run scripts to aggregate latency and process raw docker stats"

runs:
  using: "composite"
  steps:
    - name: Set up Nim
      uses: jiro4989/setup-nim-action@v2
      with:
        nim-version: "2.x"
        repo-token: ${{ env.GITHUB_TOKEN }}

    - name: Aggregate latency stats and prepare markdown for comment and summary
      shell: bash
      run: |
        nim c -r -d:release -o:/tmp/process_latency_stats ./performance/scripts/process_latency_stats.nim

    - name: Process raw docker stats to csv files
      shell: bash
      run: |
        nim c -r -d:release -o:/tmp/process_docker_stats ./performance/scripts/process_docker_stats.nim
36  .github/actions/publish_history/action.yml  (vendored, new file)
@@ -0,0 +1,36 @@
name: Publish Latency History
description: "Publish latency history CSVs in a configurable branch and folder"

runs:
  using: "composite"
  steps:
    - name: Clone the branch
      uses: actions/checkout@v4
      with:
        repository: ${{ github.repository }}
        ref: ${{ env.PUBLISH_BRANCH_NAME }}
        path: ${{ env.CHECKOUT_SUBFOLDER_HISTORY }}
        fetch-depth: 0

    - name: Commit & push latency history CSVs
      shell: bash
      run: |
        cd "$CHECKOUT_SUBFOLDER_HISTORY"
        git fetch origin "$PUBLISH_BRANCH_NAME"
        git reset --hard "origin/$PUBLISH_BRANCH_NAME"

        mkdir -p "$PUBLISH_DIR_LATENCY_HISTORY"

        cp ../$SHARED_VOLUME_PATH/$LATENCY_HISTORY_PREFIX*.csv "$PUBLISH_DIR_LATENCY_HISTORY/"
        git add "$PUBLISH_DIR_LATENCY_HISTORY"

        if git diff-index --quiet HEAD --; then
          echo "No changes to commit"
        else
          git config user.email "github-actions[bot]@users.noreply.github.com"
          git config user.name "github-actions[bot]"
          git commit -m "Update latency history CSVs"
          git push origin "$PUBLISH_BRANCH_NAME"
        fi

        cd ..
56  .github/actions/publish_plots/action.yml  (vendored, new file)
@@ -0,0 +1,56 @@
name: Publish Plots
description: "Publish plots in performance_plots branch and add to the workflow summary"

runs:
  using: "composite"
  steps:
    - name: Clone the performance_plots branch
      uses: actions/checkout@v4
      with:
        repository: ${{ github.repository }}
        ref: ${{ env.PUBLISH_BRANCH_NAME }}
        path: ${{ env.CHECKOUT_SUBFOLDER_SUBPLOTS }}
        fetch-depth: 0

    - name: Commit & push plots
      shell: bash
      run: |
        cd $CHECKOUT_SUBFOLDER_SUBPLOTS
        git fetch origin "$PUBLISH_BRANCH_NAME"
        git reset --hard "origin/$PUBLISH_BRANCH_NAME"

        # Remove any branch folder older than 7 days
        DAYS=7
        cutoff=$(( $(date +%s) - DAYS*24*3600 ))
        scan_dir="${PUBLISH_DIR_PLOTS%/}"
        find "$scan_dir" -mindepth 1 -maxdepth 1 -type d -print0 \
          | while IFS= read -r -d $'\0' d; do \
              ts=$(git log -1 --format=%ct -- "$d" 2>/dev/null || true); \
              if [ -n "$ts" ] && [ "$ts" -le "$cutoff" ]; then \
                echo "[cleanup] Deleting: $d"; \
                rm -rf -- "$d"; \
              fi; \
            done

        rm -rf $PUBLISH_DIR_PLOTS/$BRANCH_NAME
        mkdir -p $PUBLISH_DIR_PLOTS/$BRANCH_NAME

        cp ../$SHARED_VOLUME_PATH/*.png $PUBLISH_DIR_PLOTS/$BRANCH_NAME/ 2>/dev/null || true
        cp ../$LATENCY_HISTORY_PATH/*.png $PUBLISH_DIR_PLOTS/ 2>/dev/null || true
        git add -A "$PUBLISH_DIR_PLOTS/"

        git status

        if git diff-index --quiet HEAD --; then
          echo "No changes to commit"
        else
          git config user.email "github-actions[bot]@users.noreply.github.com"
          git config user.name "github-actions[bot]"
          git commit -m "Update performance plots for $BRANCH_NAME"
          git push origin $PUBLISH_BRANCH_NAME
        fi

    - name: Add plots to GitHub Actions summary
      shell: bash
      run: |
        nim c -r -d:release -o:/tmp/add_plots_to_summary ./performance/scripts/add_plots_to_summary.nim
8  .github/workflows/ci.yml  (vendored)
@@ -25,15 +25,11 @@ jobs:
            cpu: i386
          - os: linux-gcc-14
            cpu: amd64
-         - os: macos
-           cpu: amd64
          - os: macos-14
            cpu: arm64
          - os: windows
            cpu: amd64
        nim:
-         - ref: version-1-6
-           memory_management: refc
          - ref: version-2-0
            memory_management: refc
          - ref: version-2-2
@@ -47,10 +43,6 @@
            os: linux-gcc-14
            builder: ubuntu-24.04
            shell: bash
-         - platform:
-             os: macos
-           builder: macos-13
-           shell: bash
          - platform:
              os: macos-14
            builder: macos-14
25  .github/workflows/daily_amd64.yml  (vendored)
@@ -7,25 +7,36 @@ on:

jobs:
  test_amd64_latest:
-   name: Daily amd64 (latest dependencies)
+   name: Daily test amd64 (latest dependencies)
    uses: ./.github/workflows/daily_common.yml
    with:
      nim: "[
        {'ref': 'version-1-6', 'memory_management': 'refc'},
        {'ref': 'version-2-0', 'memory_management': 'refc'},
        {'ref': 'version-2-0', 'memory_management': 'refc'},
        {'ref': 'version-2-2', 'memory_management': 'refc'},
        {'ref': 'devel', 'memory_management': 'refc'},
      ]"
      cpu: "['amd64']"
  test_amd64_pinned:
-   name: Daily amd64 (pinned dependencies)
+   name: Daily test amd64 (pinned dependencies)
    uses: ./.github/workflows/daily_common.yml
    with:
      pinned_deps: true
      nim: "[
        {'ref': 'version-1-6', 'memory_management': 'refc'},
        {'ref': 'version-2-0', 'memory_management': 'refc'},
        {'ref': 'version-2-0', 'memory_management': 'refc'},
        {'ref': 'version-2-2', 'memory_management': 'refc'},
        {'ref': 'devel', 'memory_management': 'refc'},
      ]"
      cpu: "['amd64']"
      cpu: "['amd64']"
  notify-on-failure:
    name: Notify Discord on Failure
    needs: [test_amd64_latest, test_amd64_pinned]
    if: failure()
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Discord notification
        uses: ./.github/actions/discord_notify
        with:
          webhook_url: ${{ secrets.DISCORD_WEBHOOK_URL }}
35  .github/workflows/daily_i386.yml  (vendored)
@@ -6,18 +6,45 @@ on:
  workflow_dispatch:

jobs:
- test_i386:
-   name: Daily i386 (Linux)
+ test_i386_latest:
+   name: Daily i386 (latest dependencies)
    uses: ./.github/workflows/daily_common.yml
    with:
      nim: "[
        {'ref': 'version-1-6', 'memory_management': 'refc'},
        {'ref': 'version-2-0', 'memory_management': 'refc'},
        {'ref': 'version-2-2', 'memory_management': 'refc'},
        {'ref': 'devel', 'memory_management': 'refc'},
      ]"
      cpu: "['i386']"
      exclude: "[
        {'platform': {'os':'macos'}},
        {'platform': {'os':'macos'}},
        {'platform': {'os':'windows'}},
      ]"
  test_i386_pinned:
    name: Daily i386 (pinned dependencies)
    uses: ./.github/workflows/daily_common.yml
    with:
      pinned_deps: true
      nim: "[
        {'ref': 'version-2-0', 'memory_management': 'refc'},
        {'ref': 'version-2-2', 'memory_management': 'refc'},
        {'ref': 'devel', 'memory_management': 'refc'},
      ]"
      cpu: "['i386']"
      exclude: "[
        {'platform': {'os':'macos'}},
        {'platform': {'os':'windows'}},
      ]"
  notify-on-failure:
    name: Notify Discord on Failure
    needs: [test_i386_latest, test_i386_pinned]
    if: failure()
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Discord notification
        uses: ./.github/actions/discord_notify
        with:
          webhook_url: ${{ secrets.DISCORD_WEBHOOK_URL }}
39  .github/workflows/daily_nimbus.yml  (vendored, new file)
@@ -0,0 +1,39 @@
name: Daily Nimbus

on:
  schedule:
    - cron: "30 6 * * *"
  workflow_dispatch:

jobs:
  compile_nimbus:
    timeout-minutes: 80
    name: 'Compile Nimbus (linux-amd64)'
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Compile nimbus using nim-libp2p
        run: |
          git clone --branch unstable --single-branch https://github.com/status-im/nimbus-eth2.git
          cd nimbus-eth2
          git submodule set-branch --branch ${{ github.sha }} vendor/nim-libp2p

          make -j"$(nproc)"
          make -j"$(nproc)" nimbus_beacon_node

  notify-on-failure:
    name: Notify Discord on Failure
    needs: compile_nimbus
    if: failure()
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Discord notification
        uses: ./.github/actions/discord_notify
        with:
          webhook_url: ${{ secrets.DISCORD_WEBHOOK_URL }}
12  .github/workflows/dependencies.yml  (vendored)
@@ -50,4 +50,16 @@ jobs:
          git branch -D nim-libp2p-auto-bump-${{ matrix.target.ref }} || true
          git switch -c nim-libp2p-auto-bump-${{ matrix.target.ref }}
          git push -f origin nim-libp2p-auto-bump-${{ matrix.target.ref }}
  notify-on-failure:
    name: Notify Discord on Failure
    needs: [bumper]
    if: failure()
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Discord notification
        uses: ./.github/actions/discord_notify
        with:
          webhook_url: ${{ secrets.DISCORD_WEBHOOK_URL }}
2  .github/workflows/documentation.yml  (vendored)
@@ -21,7 +21,7 @@ jobs:

      - uses: jiro4989/setup-nim-action@v1
        with:
-         nim-version: '1.6.x'
+         nim-version: '2.2.x'

      - name: Generate doc
        run: |
2  .github/workflows/examples.yml  (vendored)
@@ -36,7 +36,7 @@ jobs:
          shell: bash
          os: linux
          cpu: amd64
-         nim_ref: version-1-6
+         nim_ref: version-2-2

      - name: Restore deps from cache
        id: deps-cache
94  .github/workflows/performance.yml  (vendored, new file)
@@ -0,0 +1,94 @@
name: Performance

on:
  push:
    branches:
      - master
  pull_request:
  merge_group:
  workflow_dispatch:

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
  cancel-in-progress: true

jobs:
  performance:
    timeout-minutes: 20
    strategy:
      fail-fast: false

    defaults:
      run:
        shell: bash

    env:
      VACP2P: "vacp2p"
      GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
      PR_HEAD_SHA: ${{ github.event.pull_request.head.sha }}
      PR_NUMBER: ${{ github.event.number }}
      BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
      MARKER: "<!-- perf-summary-marker -->"
      COMMENT_SUMMARY_PATH: "/tmp/perf-summary.md"
      SHARED_VOLUME_PATH: "performance/output"
      DOCKER_STATS_PREFIX: "docker_stats_"
      PUBLISH_BRANCH_NAME: "performance_plots"
      CHECKOUT_SUBFOLDER_SUBPLOTS: "subplots"
      PUBLISH_DIR_PLOTS: "plots"
      CHECKOUT_SUBFOLDER_HISTORY: "history"
      PUBLISH_DIR_LATENCY_HISTORY: "latency_history"
      LATENCY_HISTORY_PATH: "history/latency_history"
      LATENCY_HISTORY_PREFIX: "pr"
      LATENCY_HISTORY_PLOT_FILENAME: "latency_history_all_scenarios.png"

    name: "Performance"
    runs-on: ubuntu-22.04
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          submodules: true

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Build Docker Image with cache
        uses: docker/build-push-action@v6
        with:
          context: .
          file: performance/Dockerfile
          tags: test-node:latest
          load: true
          cache-from: type=gha
          cache-to: type=gha,mode=max

      - name: Run
        run: |
          ./performance/runner.sh

      - name: Process latency and docker stats
        uses: ./.github/actions/process_stats

      - name: Publish history
        if: github.repository_owner == env.VACP2P
        uses: ./.github/actions/publish_history

      - name: Generate plots
        if: github.repository_owner == env.VACP2P
        uses: ./.github/actions/generate_plots

      - name: Post/Update PR comment
        if: github.event_name == 'pull_request'
        uses: ./.github/actions/add_comment

      - name: Upload performance artifacts
        if: success() || failure()
        uses: actions/upload-artifact@v4
        with:
          name: performance-artifacts
          path: |
            performance/output/pr*_latency.csv
            performance/output/*.png
            history/latency_history/*.png
          if-no-files-found: ignore
          retention-days: 7
4  .pinned
@@ -1,5 +1,5 @@
bearssl;https://github.com/status-im/nim-bearssl@#34d712933a4e0f91f5e66bc848594a581504a215
-chronicles;https://github.com/status-im/nim-chronicles@#81a4a7a360c78be9c80c8f735c76b6d4a1517304
+chronicles;https://github.com/status-im/nim-chronicles@#61759a5e8df8f4d68bcd1b4b8c1adab3e72bbd8d
chronos;https://github.com/status-im/nim-chronos@#b55e2816eb45f698ddaca8d8473e401502562db2
dnsclient;https://github.com/ba0f3/dnsclient.nim@#23214235d4784d24aceed99bbfe153379ea557c8
faststreams;https://github.com/status-im/nim-faststreams@#c51315d0ae5eb2594d0bf41181d0e1aca1b3c01d
@@ -8,7 +8,7 @@ json_serialization;https://github.com/status-im/nim-json-serialization@#2b1c5eb1
metrics;https://github.com/status-im/nim-metrics@#6142e433fc8ea9b73379770a788017ac528d46ff
ngtcp2;https://github.com/status-im/nim-ngtcp2@#9456daa178c655bccd4a3c78ad3b8cce1f0add73
nimcrypto;https://github.com/cheatfate/nimcrypto@#19c41d6be4c00b4a2c8000583bd30cf8ceb5f4b1
-quic;https://github.com/status-im/nim-quic.git@#ca3eda53bee9cef7379be195738ca1490877432f
+quic;https://github.com/vacp2p/nim-quic@#9370190ded18d78a5a9990f57aa8cbbf947f3891
results;https://github.com/arnetheduck/nim-results@#df8113dda4c2d74d460a8fa98252b0b771bf1f27
secp256k1;https://github.com/status-im/nim-secp256k1@#f808ed5e7a7bfc42204ec7830f14b7a42b63c284
serialization;https://github.com/status-im/nim-serialization@#548d0adc9797a10b2db7f788b804330306293088
14  README.md
@@ -39,7 +39,7 @@ Learn more about libp2p at [**libp2p.io**](https://libp2p.io) and follow libp2p'

## Install

-> The currently supported Nim versions are 1.6, 2.0 and 2.2.
+> The currently supported Nim versions are 2.0 and 2.2.

```
nimble install libp2p
@@ -71,6 +71,10 @@ git clone https://github.com/vacp2p/nim-libp2p
cd nim-libp2p
nimble install -dy
```
+You can use `nix develop` to start a shell with Nim and Nimble.
+
+nimble 0.20.1 is required for running `testnative`. At time of writing, this is not available in nixpkgs: If using `nix develop`, follow up with `nimble install nimble`, and use that (typically `~/.nimble/bin/nimble`).
+
### Testing
Run unit tests:
```sh
@@ -97,6 +101,7 @@ The libp2p implementation in Nim is a work in progress. We welcome contributors
- **Add tests**. Help nim-libp2p to be more robust by adding more tests to the [tests folder](tests/).
- **Small PRs**. Try to keep PRs atomic and digestible. This makes the review process and pinpointing bugs easier.
- **Code format**. Code should be formatted with [nph](https://github.com/arnetheduck/nph) and follow the [Status Nim Style Guide](https://status-im.github.io/nim-style-guide/).
+- **Join the Conversation**. Connect with other contributors in our [community channel](https://discord.com/channels/1204447718093750272/1351621032263417946). Ask questions, share ideas, get support, and stay informed about the latest updates from the maintainers.

### Contributors
<a href="https://github.com/vacp2p/nim-libp2p/graphs/contributors"><img src="https://contrib.rocks/image?repo=vacp2p/nim-libp2p" alt="nim-libp2p contributors"></a>
@@ -119,6 +124,11 @@ Enable quic transport support
nim c -d:libp2p_quic_support some_file.nim
```

+Enable autotls support
+```bash
+nim c -d:libp2p_autotls_support some_file.nim
+```
+
Enable expensive metrics (ie, metrics with per-peer cardinality):
```bash
nim c -d:libp2p_expensive_metrics some_file.nim
@@ -190,7 +200,7 @@ The versioning follows [semver](https://semver.org/), with some additions:
- Some of libp2p procedures are marked as `.public.`, they will remain compatible during each `MAJOR` version
- The rest of the procedures are considered internal, and can change at any `MINOR` version (but remain compatible for each new `PATCH`)

-We aim to be compatible at all time with at least 2 Nim `MINOR` versions, currently `1.6 & 2.0`
+We aim to be compatible at all time with at least 2 Nim `MINOR` versions, currently `2.0 & 2.2`

## License
@@ -1,3 +1,4 @@
+{.used.}
## # Circuit Relay example
##
## Circuit Relay can be used when a node cannot reach another node

@@ -1,3 +1,4 @@
+{.used.}
when not (compileOption("threads")):
  {.fatal: "Please, compile this program with the --threads:on option!".}

@@ -1,3 +1,5 @@
+{.used.}
+
import chronos # an efficient library for async
import stew/byteutils # various utils
import libp2p

@@ -1,3 +1,4 @@
+{.used.}
## # Simple ping tutorial
##
## Hi all, welcome to the first nim-libp2p tutorial!

@@ -1,3 +1,4 @@
+{.used.}
## # Custom protocol in libp2p
##
## In the [previous tutorial](tutorial_1_connect.md), we've looked at how to create a simple ping program using the `nim-libp2p`.

@@ -1,3 +1,4 @@
+{.used.}
## # Protobuf usage
##
## In the [previous tutorial](tutorial_2_customproto.md), we created a simple "ping" protocol.

@@ -1,3 +1,4 @@
+{.used.}
## # GossipSub
##
## In this tutorial, we'll build a simple GossipSub network

@@ -1,3 +1,4 @@
+{.used.}
## # Discovery Manager
##
## In the [previous tutorial](tutorial_4_gossipsub.md), we built a custom protocol using [protobuf](https://developers.google.com/protocol-buffers) and

@@ -1,3 +1,4 @@
+{.used.}
## # Tron example
##
## In this tutorial, we will create a video game based on libp2p, using
27  flake.lock  (generated, new file)
@@ -0,0 +1,27 @@
{
  "nodes": {
    "nixpkgs": {
      "locked": {
        "lastModified": 1752620740,
        "narHash": "sha256-f3pO+9lg66mV7IMmmIqG4PL3223TYMlnlw+pnpelbss=",
        "owner": "NixOS",
        "repo": "nixpkgs",
        "rev": "32a4e87942101f1c9f9865e04dc3ddb175f5f32e",
        "type": "github"
      },
      "original": {
        "owner": "NixOS",
        "ref": "nixos-25.05",
        "repo": "nixpkgs",
        "type": "github"
      }
    },
    "root": {
      "inputs": {
        "nixpkgs": "nixpkgs"
      }
    }
  },
  "root": "root",
  "version": 7
}
34  flake.nix  (new file)
@@ -0,0 +1,34 @@
{
  description = "nim-libp2p dev shell flake";

  nixConfig = {
    extra-substituters = [ "https://nix-cache.status.im/" ];
    extra-trusted-public-keys = [ "nix-cache.status.im-1:x/93lOfLU+duPplwMSBR+OlY4+mo+dCN7n0mr4oPwgY=" ];
  };

  inputs = {
    nixpkgs.url = "github:NixOS/nixpkgs/nixos-25.05";
  };

  outputs = { self, nixpkgs }:
    let
      stableSystems = [
        "x86_64-linux" "aarch64-linux" "armv7a-linux"
        "x86_64-darwin" "aarch64-darwin"
        "x86_64-windows"
      ];
      forEach = nixpkgs.lib.genAttrs;
      forAllSystems = forEach stableSystems;
      pkgsFor = forEach stableSystems (
        system: import nixpkgs { inherit system; }
      );
    in rec {
      devShells = forAllSystems (system: {
        default = pkgsFor.${system}.mkShell {
          nativeBuildInputs = with pkgsFor.${system}; [
            nim-2_2 nimble openssl.dev
          ];
        };
      });
    };
}
@@ -1,5 +1,5 @@
# syntax=docker/dockerfile:1.5-labs
-FROM nimlang/nim:1.6.16 as builder
+FROM nimlang/nim:latest as builder

WORKDIR /workspace

@@ -7,11 +7,11 @@ COPY .pinned libp2p.nimble nim-libp2p/

RUN --mount=type=cache,target=/var/cache/apt apt-get update && apt-get install -y libssl-dev

-RUN cd nim-libp2p && nimble install_pinned && nimble install "redis@#b341fe240dbf11c544011dd0e033d3c3acca56af" -y
+RUN cd nim-libp2p && nimble install_pinned && nimble install redis -y

COPY . nim-libp2p/

-RUN cd nim-libp2p && nim c --skipParentCfg --NimblePath:./nimbledeps/pkgs --mm:refc -d:chronicles_log_level=DEBUG -d:chronicles_default_output_device=stderr -d:release --threads:off --skipProjCfg -o:hole-punching-tests ./interop/hole-punching/hole_punching.nim
+RUN cd nim-libp2p && nim c --skipParentCfg --NimblePath:./nimbledeps/pkgs2 --mm:refc -d:chronicles_log_level=DEBUG -d:chronicles_default_output_device=stderr -d:release --threads:off --skipProjCfg -o:hole-punching-tests ./interop/hole-punching/hole_punching.nim

FROM --platform=linux/amd64 debian:bullseye-slim
RUN --mount=type=cache,target=/var/cache/apt apt-get update && apt-get install -y dnsutils jq curl tcpdump iproute2 libssl-dev
@@ -1,18 +1,18 @@
# syntax=docker/dockerfile:1.5-labs
-FROM nimlang/nim:1.6.16 as builder
+FROM nimlang/nim:latest as builder

WORKDIR /app

COPY .pinned libp2p.nimble nim-libp2p/
COPY .pinned libp2p.nimble nim-libp2p/

RUN --mount=type=cache,target=/var/cache/apt apt-get update && apt-get install -y libssl-dev

-RUN cd nim-libp2p && nimble install_pinned && nimble install "redis@#b341fe240dbf11c544011dd0e033d3c3acca56af" -y
+RUN cd nim-libp2p && nimble install_pinned && nimble install redis -y

COPY . nim-libp2p/

RUN \
  cd nim-libp2p && \
- nim c --skipProjCfg --skipParentCfg --NimblePath:./nimbledeps/pkgs -p:nim-libp2p --mm:refc -d:libp2p_quic_support -d:chronicles_log_level=WARN -d:chronicles_default_output_device=stderr --threads:off ./interop/transport/main.nim
+ nim c --skipProjCfg --skipParentCfg --NimblePath:./nimbledeps/pkgs2 -p:nim-libp2p --mm:refc -d:libp2p_quic_support -d:chronicles_log_level=WARN -d:chronicles_default_output_device=stderr --threads:off ./interop/transport/main.nim

ENTRYPOINT ["/app/nim-libp2p/interop/transport/main"]
@@ -1,16 +1,16 @@
mode = ScriptMode.Verbose

packageName = "libp2p"
-version = "1.11.0"
+version = "1.12.0"
author = "Status Research & Development GmbH"
description = "LibP2P implementation"
license = "MIT"
skipDirs = @["tests", "examples", "Nim", "tools", "scripts", "docs"]

-requires "nim >= 1.6.0",
+requires "nim >= 2.0.0",
  "nimcrypto >= 0.6.0 & < 0.7.0", "dnsclient >= 0.3.0 & < 0.4.0", "bearssl >= 0.2.5",
- "chronicles >= 0.10.3 & < 0.11.0", "chronos >= 4.0.4", "metrics", "secp256k1",
- "stew >= 0.4.0", "websock >= 0.2.0", "unittest2", "results", "quic >= 0.2.7",
+ "chronicles >= 0.11.0 & < 0.12.0", "chronos >= 4.0.4", "metrics", "secp256k1",
+ "stew >= 0.4.0", "websock >= 0.2.0", "unittest2", "results", "quic >= 0.2.16",
  "https://github.com/vacp2p/nim-jwt.git#18f8378de52b241f321c1f9ea905456e89b95c6f"

let nimc = getEnv("NIMC", "nim") # Which nim compiler to use
@@ -30,7 +30,7 @@ proc runTest(filename: string, moreoptions: string = "") =
  excstr.add(" " & moreoptions & " ")
  if getEnv("CICOV").len > 0:
    excstr &= " --nimcache:nimcache/" & filename & "-" & $excstr.hash
- exec excstr & " -r -d:libp2p_quic_support tests/" & filename
+ exec excstr & " -r -d:libp2p_quic_support -d:libp2p_autotls_support tests/" & filename
  rmFile "tests/" & filename.toExe

proc buildSample(filename: string, run = false, extraFlags = "") =
@@ -1,6 +1,6 @@
-import options, sequtils, strutils, json, uri
+import json, uri
from times import DateTime, parse
-import chronos/apps/http/httpclient, jwt, results, bearssl/pem, chronicles
+import chronos/apps/http/httpclient, results, chronicles

import ./utils
import ../../crypto/crypto
@@ -158,352 +158,376 @@ type ACMECertificateResponse* = object
  rawCertificate*: string
  certificateExpiry*: DateTime

template handleError*(msg: string, body: untyped): untyped =
  try:
    body
  except ACMEError as exc:
    raise exc
  except CancelledError as exc:
    raise exc
  except JsonKindError as exc:
    raise newException(ACMEError, msg & ": Failed to decode JSON", exc)
  except ValueError as exc:
    raise newException(ACMEError, msg & ": Failed to decode JSON", exc)
  except HttpError as exc:
    raise newException(ACMEError, msg & ": Failed to connect to ACME server", exc)
  except CatchableError as exc:
    raise newException(ACMEError, msg & ": Unexpected error", exc)
type ACMECertificate* = object
  rawCertificate*: string
  certificateExpiry*: DateTime
  certKeyPair*: KeyPair

method post*(
  self: ACMEApi, uri: Uri, payload: string
): Future[HTTPResponse] {.
  async: (raises: [ACMEError, HttpError, CancelledError]), base
.}
when defined(libp2p_autotls_support):
  import options, sequtils, strutils, jwt, bearssl/pem

method get*(
  self: ACMEApi, uri: Uri
): Future[HTTPResponse] {.
  async: (raises: [ACMEError, HttpError, CancelledError]), base
.}
template handleError*(msg: string, body: untyped): untyped =
  try:
    body
  except ACMEError as exc:
    raise exc
  except CancelledError as exc:
    raise exc
  except JsonKindError as exc:
    raise newException(ACMEError, msg & ": Failed to decode JSON", exc)
  except ValueError as exc:
    raise newException(ACMEError, msg & ": Failed to decode JSON", exc)
  except HttpError as exc:
    raise newException(ACMEError, msg & ": Failed to connect to ACME server", exc)
  except CatchableError as exc:
    raise newException(ACMEError, msg & ": Unexpected error", exc)

proc new*(
  T: typedesc[ACMEApi], acmeServerURL: Uri = parseUri(LetsEncryptURL)
): ACMEApi =
  let session = HttpSessionRef.new()

  ACMEApi(
    session: session, directory: Opt.none(ACMEDirectory), acmeServerURL: acmeServerURL
  )

proc getDirectory(
  self: ACMEApi
): Future[ACMEDirectory] {.async: (raises: [ACMEError, CancelledError]).} =
  handleError("getDirectory"):
    self.directory.valueOr:
      let acmeResponse = await self.get(self.acmeServerURL / "directory")
      let directory = acmeResponse.body.to(ACMEDirectory)
      self.directory = Opt.some(directory)
      directory

method requestNonce*(
  self: ACMEApi
): Future[Nonce] {.async: (raises: [ACMEError, CancelledError]), base.} =
  handleError("requestNonce"):
    let acmeResponse = await self.get(parseUri((await self.getDirectory()).newNonce))
    Nonce(acmeResponse.headers.keyOrError("Replay-Nonce"))

# TODO: save n and e in account so we don't have to recalculate every time
proc acmeHeader(
  self: ACMEApi, uri: Uri, key: KeyPair, needsJwk: bool, kid: Opt[Kid]
): Future[ACMERequestHeader] {.async: (raises: [ACMEError, CancelledError]).} =
  if not needsJwk and kid.isNone():
    raise newException(ACMEError, "kid not set")

  if key.pubkey.scheme != PKScheme.RSA or key.seckey.scheme != PKScheme.RSA:
    raise newException(ACMEError, "Unsupported signing key type")

  let newNonce = await self.requestNonce()
  if needsJwk:
    let pubkey = key.pubkey.rsakey
    let nArray = @(getArray(pubkey.buffer, pubkey.key.n, pubkey.key.nlen))
    let eArray = @(getArray(pubkey.buffer, pubkey.key.e, pubkey.key.elen))
    ACMERequestHeader(
      kind: ACMEJwkRequest,
      alg: Alg,
      typ: "JWT",
      nonce: newNonce,
      url: $uri,
      jwk: JWK(kty: "RSA", n: base64UrlEncode(nArray), e: base64UrlEncode(eArray)),
    )
  else:
    ACMERequestHeader(
      kind: ACMEKidRequest,
      alg: Alg,
      typ: "JWT",
      nonce: newNonce,
      url: $uri,
      kid: kid.get(),
    )

method post*(
method post*(
  self: ACMEApi, uri: Uri, payload: string
): Future[HTTPResponse] {.
): Future[HTTPResponse] {.
  async: (raises: [ACMEError, HttpError, CancelledError]), base
.} =
  let rawResponse = await HttpClientRequestRef
    .post(self.session, $uri, body = payload, headers = ACMEHttpHeaders)
    .get()
    .send()
  let body = await rawResponse.getResponseBody()
  HTTPResponse(body: body, headers: rawResponse.headers)
.}

method get*(
method get*(
  self: ACMEApi, uri: Uri
): Future[HTTPResponse] {.
): Future[HTTPResponse] {.
  async: (raises: [ACMEError, HttpError, CancelledError]), base
.} =
  let rawResponse = await HttpClientRequestRef.get(self.session, $uri).get().send()
  let body = await rawResponse.getResponseBody()
  HTTPResponse(body: body, headers: rawResponse.headers)
.}

proc createSignedAcmeRequest(
  self: ACMEApi,
  uri: Uri,
  payload: auto,
  key: KeyPair,
  needsJwk: bool = false,
  kid: Opt[Kid] = Opt.none(Kid),
): Future[string] {.async: (raises: [ACMEError, CancelledError]).} =
  if key.pubkey.scheme != PKScheme.RSA or key.seckey.scheme != PKScheme.RSA:
    raise newException(ACMEError, "Unsupported signing key type")
proc new*(
  T: typedesc[ACMEApi], acmeServerURL: Uri = parseUri(LetsEncryptURL)
): ACMEApi =
  let session = HttpSessionRef.new()

  let acmeHeader = await self.acmeHeader(uri, key, needsJwk, kid)
  handleError("createSignedAcmeRequest"):
    var token = toJWT(%*{"header": acmeHeader, "claims": payload})
    let derPrivKey = key.seckey.rsakey.getBytes.get
    let pemPrivKey: string = pemEncode(derPrivKey, "PRIVATE KEY")
    token.sign(pemPrivKey)
    $token.toFlattenedJson()

proc requestRegister*(
  self: ACMEApi, key: KeyPair
): Future[ACMERegisterResponse] {.async: (raises: [ACMEError, CancelledError]).} =
  let registerRequest = ACMERegisterRequest(termsOfServiceAgreed: true)
  handleError("acmeRegister"):
    let payload = await self.createSignedAcmeRequest(
      parseUri((await self.getDirectory()).newAccount),
      registerRequest,
      key,
      needsJwk = true,
    )
    let acmeResponse =
      await self.post(parseUri((await self.getDirectory()).newAccount), payload)
    let acmeResponseBody = acmeResponse.body.to(ACMERegisterResponseBody)

    ACMERegisterResponse(
      status: acmeResponseBody.status, kid: acmeResponse.headers.keyOrError("location")
  ACMEApi(
    session: session, directory: Opt.none(ACMEDirectory), acmeServerURL: acmeServerURL
    )

proc requestNewOrder*(
  self: ACMEApi, domains: seq[Domain], key: KeyPair, kid: Kid
): Future[ACMEChallengeResponse] {.async: (raises: [ACMEError, CancelledError]).} =
  # request challenge from ACME server
  let orderRequest = ACMEChallengeRequest(
    identifiers: domains.mapIt(ACMEChallengeIdentifier(`type`: "dns", value: it))
  )
  handleError("requestNewOrder"):
    let payload = await self.createSignedAcmeRequest(
      parseUri((await self.getDirectory()).newOrder),
      orderRequest,
      key,
      kid = Opt.some(kid),
    )
    let acmeResponse =
      await self.post(parseUri((await self.getDirectory()).newOrder), payload)
    let challengeResponseBody = acmeResponse.body.to(ACMEChallengeResponseBody)
    if challengeResponseBody.authorizations.len == 0:
      raise newException(ACMEError, "Authorizations field is empty")
    ACMEChallengeResponse(
      status: challengeResponseBody.status,
      authorizations: challengeResponseBody.authorizations,
      finalize: challengeResponseBody.finalize,
      order: acmeResponse.headers.keyOrError("location"),
    )
proc getDirectory(
  self: ACMEApi
): Future[ACMEDirectory] {.async: (raises: [ACMEError, CancelledError]).} =
  handleError("getDirectory"):
    self.directory.valueOr:
      let acmeResponse = await self.get(self.acmeServerURL / "directory")
      let directory = acmeResponse.body.to(ACMEDirectory)
      self.directory = Opt.some(directory)
      directory

proc requestAuthorizations*(
  self: ACMEApi, authorizations: seq[Authorization], key: KeyPair, kid: Kid
): Future[ACMEAuthorizationsResponse] {.async: (raises: [ACMEError, CancelledError]).} =
  handleError("requestAuthorizations"):
    doAssert authorizations.len > 0
    let acmeResponse = await self.get(parseUri(authorizations[0]))
    acmeResponse.body.to(ACMEAuthorizationsResponse)
method requestNonce*(
  self: ACMEApi
): Future[Nonce] {.async: (raises: [ACMEError, CancelledError]), base.} =
  handleError("requestNonce"):
    let acmeResponse = await self.get(parseUri((await self.getDirectory()).newNonce))
    Nonce(acmeResponse.headers.keyOrError("Replay-Nonce"))

proc requestChallenge*(
  self: ACMEApi, domains: seq[Domain], key: KeyPair, kid: Kid
): Future[ACMEChallengeResponseWrapper] {.async: (raises: [ACMEError, CancelledError]).} =
  let orderResponse = await self.requestNewOrder(domains, key, kid)
  if orderResponse.status != ACMEOrderStatus.PENDING and
      orderResponse.status != ACMEOrderStatus.READY:
    # ready is a valid status when renewing certs before expiry
    raise newException(ACMEError, "Invalid new order status: " & $orderResponse.status)
# TODO: save n and e in account so we don't have to recalculate every time
proc acmeHeader(
  self: ACMEApi, uri: Uri, key: KeyPair, needsJwk: bool, kid: Opt[Kid]
): Future[ACMERequestHeader] {.async: (raises: [ACMEError, CancelledError]).} =
  if not needsJwk and kid.isNone():
    raise newException(ACMEError, "kid not set")

  let authorizationsResponse =
    await self.requestAuthorizations(orderResponse.authorizations, key, kid)
  if authorizationsResponse.challenges.len == 0:
    raise newException(ACMEError, "No challenges received")
  if key.pubkey.scheme != PKScheme.RSA or key.seckey.scheme != PKScheme.RSA:
    raise newException(ACMEError, "Unsupported signing key type")

  return ACMEChallengeResponseWrapper(
    finalize: orderResponse.finalize,
    order: orderResponse.order,
    dns01: authorizationsResponse.challenges.filterIt(
      it.`type` == ACMEChallengeType.DNS01
    )[0],
    # getting the first element is safe since we checked that authorizationsResponse.challenges.len != 0
  )

proc requestCheck*(
  self: ACMEApi, checkURL: Uri, checkKind: ACMECheckKind, key: KeyPair, kid: Kid
): Future[ACMECheckResponse] {.async: (raises: [ACMEError, CancelledError]).} =
  handleError("requestCheck"):
    let acmeResponse = await self.get(checkURL)
    let retryAfter =
      try:
        parseInt(acmeResponse.headers.keyOrError("Retry-After")).seconds
      except ValueError:
        DefaultChalCompletedRetryTime

    case checkKind
    of ACMEOrderCheck:
      try:
        ACMECheckResponse(
          kind: checkKind,
          orderStatus: parseEnum[ACMEOrderStatus](acmeResponse.body["status"].getStr),
          retryAfter: retryAfter,
        )
      except ValueError:
        raise newException(
          ACMEError, "Invalid order status: " & acmeResponse.body["status"].getStr
        )
    of ACMEChallengeCheck:
      try:
        ACMECheckResponse(
          kind: checkKind,
          chalStatus: parseEnum[ACMEChallengeStatus](acmeResponse.body["status"].getStr),
          retryAfter: retryAfter,
        )
      except ValueError:
        raise newException(
          ACMEError, "Invalid order status: " & acmeResponse.body["status"].getStr
        )

proc sendChallengeCompleted*(
  self: ACMEApi, chalURL: Uri, key: KeyPair, kid: Kid
): Future[ACMECompletedResponse] {.async: (raises: [ACMEError, CancelledError]).} =
  handleError("sendChallengeCompleted"):
    let payload =
      await self.createSignedAcmeRequest(chalURL, %*{}, key, kid = Opt.some(kid))
    let acmeResponse = await self.post(chalURL, payload)
    acmeResponse.body.to(ACMECompletedResponse)

proc checkChallengeCompleted*(
  self: ACMEApi,
  checkURL: Uri,
  key: KeyPair,
  kid: Kid,
  retries: int = DefaultChalCompletedRetries,
): Future[bool] {.async: (raises: [ACMEError, CancelledError]).} =
  for i in 0 .. retries:
    let checkResponse = await self.requestCheck(checkURL, ACMEChallengeCheck, key, kid)
    case checkResponse.chalStatus
    of ACMEChallengeStatus.PENDING:
      await sleepAsync(checkResponse.retryAfter) # try again after some delay
    of ACMEChallengeStatus.VALID:
      return true
    else:
      raise newException(
        ACMEError,
        "Failed challenge completion: expected 'valid', got '" &
          $checkResponse.chalStatus & "'",
  let newNonce = await self.requestNonce()
  if needsJwk:
    let pubkey = key.pubkey.rsakey
    let nArray = @(getArray(pubkey.buffer, pubkey.key.n, pubkey.key.nlen))
    let eArray = @(getArray(pubkey.buffer, pubkey.key.e, pubkey.key.elen))
    ACMERequestHeader(
      kind: ACMEJwkRequest,
      alg: Alg,
      typ: "JWT",
      nonce: newNonce,
      url: $uri,
      jwk: JWK(kty: "RSA", n: base64UrlEncode(nArray), e: base64UrlEncode(eArray)),
      )
  return false

proc completeChallenge*(
  self: ACMEApi,
  chalURL: Uri,
  key: KeyPair,
  kid: Kid,
  retries: int = DefaultChalCompletedRetries,
): Future[bool] {.async: (raises: [ACMEError, CancelledError]).} =
  let completedResponse = await self.sendChallengeCompleted(chalURL, key, kid)
  # check until acme server is done (poll validation)
  return await self.checkChallengeCompleted(chalURL, key, kid, retries = retries)

proc requestFinalize*(
  self: ACMEApi, domain: Domain, finalize: Uri, key: KeyPair, kid: Kid
): Future[ACMEFinalizeResponse] {.async: (raises: [ACMEError, CancelledError]).} =
  handleError("requestFinalize"):
    let payload = await self.createSignedAcmeRequest(
      finalize, %*{"csr": createCSR(domain)}, key, kid = Opt.some(kid)
    )
    let acmeResponse = await self.post(finalize, payload)
    # server responds with updated order response
    acmeResponse.body.to(ACMEFinalizeResponse)

proc checkCertFinalized*(
  self: ACMEApi,
  order: Uri,
  key: KeyPair,
  kid: Kid,
  retries: int = DefaultChalCompletedRetries,
): Future[bool] {.async: (raises: [ACMEError, CancelledError]).} =
  for i in 0 .. retries:
    let checkResponse = await self.requestCheck(order, ACMEOrderCheck, key, kid)
    case checkResponse.orderStatus
    of ACMEOrderStatus.VALID:
      return true
    of ACMEOrderStatus.PROCESSING:
      await sleepAsync(checkResponse.retryAfter) # try again after some delay
    else:
      error "Failed certificate finalization",
        description = "expected 'valid', got '" & $checkResponse.orderStatus & "'"
      return false # do not try again
    ACMERequestHeader(
      kind: ACMEKidRequest,
      alg: Alg,
      typ: "JWT",
      nonce: newNonce,
      url: $uri,
      kid: kid.get(),
    )

  return false

proc certificateFinalized*(
  self: ACMEApi,
  domain: Domain,
  finalize: Uri,
  order: Uri,
  key: KeyPair,
  kid: Kid,
  retries: int = DefaultFinalizeRetries,
): Future[bool] {.async: (raises: [ACMEError, CancelledError]).} =
  let finalizeResponse = await self.requestFinalize(domain, finalize, key, kid)
  # keep checking order until cert is valid (done)
  return await self.checkCertFinalized(order, key, kid, retries = retries)

proc requestGetOrder*(
  self: ACMEApi, order: Uri
): Future[ACMEOrderResponse] {.async: (raises: [ACMEError, CancelledError]).} =
  handleError("requestGetOrder"):
    let acmeResponse = await self.get(order)
    acmeResponse.body.to(ACMEOrderResponse)

proc downloadCertificate*(
  self: ACMEApi, order: Uri
): Future[ACMECertificateResponse] {.async: (raises: [ACMEError, CancelledError]).} =
  let orderResponse = await self.requestGetOrder(order)

  handleError("downloadCertificate"):
method post*(
  self: ACMEApi, uri: Uri, payload: string
): Future[HTTPResponse] {.
  async: (raises: [ACMEError, HttpError, CancelledError]), base
.} =
    let rawResponse = await HttpClientRequestRef
      .get(self.session, orderResponse.certificate)
    .post(self.session, $uri, body = payload, headers = ACMEHttpHeaders)
      .get()
      .send()
    ACMECertificateResponse(
      rawCertificate: bytesToString(await rawResponse.getBodyBytes()),
      certificateExpiry: parse(orderResponse.expires, "yyyy-MM-dd'T'HH:mm:ss'Z'"),
  let body = await rawResponse.getResponseBody()
  HTTPResponse(body: body, headers: rawResponse.headers)

method get*(
  self: ACMEApi, uri: Uri
): Future[HTTPResponse] {.
  async: (raises: [ACMEError, HttpError, CancelledError]), base
.} =
  let rawResponse = await HttpClientRequestRef.get(self.session, $uri).get().send()
  let body = await rawResponse.getResponseBody()
  HTTPResponse(body: body, headers: rawResponse.headers)

proc createSignedAcmeRequest(
  self: ACMEApi,
  uri: Uri,
  payload: auto,
  key: KeyPair,
  needsJwk: bool = false,
  kid: Opt[Kid] = Opt.none(Kid),
): Future[string] {.async: (raises: [ACMEError, CancelledError]).} =
  if key.pubkey.scheme != PKScheme.RSA or key.seckey.scheme != PKScheme.RSA:
    raise newException(ACMEError, "Unsupported signing key type")

  let acmeHeader = await self.acmeHeader(uri, key, needsJwk, kid)
  handleError("createSignedAcmeRequest"):
    var token = toJWT(%*{"header": acmeHeader, "claims": payload})
    let derPrivKey = key.seckey.rsakey.getBytes.get
    let pemPrivKey: string = pemEncode(derPrivKey, "PRIVATE KEY")
    token.sign(pemPrivKey)
    $token.toFlattenedJson()

proc requestRegister*(
  self: ACMEApi, key: KeyPair
): Future[ACMERegisterResponse] {.async: (raises: [ACMEError, CancelledError]).} =
  let registerRequest = ACMERegisterRequest(termsOfServiceAgreed: true)
  handleError("acmeRegister"):
    let payload = await self.createSignedAcmeRequest(
      parseUri((await self.getDirectory()).newAccount),
      registerRequest,
      key,
      needsJwk = true,
    )
    let acmeResponse =
      await self.post(parseUri((await self.getDirectory()).newAccount), payload)
    let acmeResponseBody = acmeResponse.body.to(ACMERegisterResponseBody)

    ACMERegisterResponse(
      status: acmeResponseBody.status,
      kid: acmeResponse.headers.keyOrError("location"),
    )

proc requestNewOrder*(
  self: ACMEApi, domains: seq[Domain], key: KeyPair, kid: Kid
): Future[ACMEChallengeResponse] {.async: (raises: [ACMEError, CancelledError]).} =
  # request challenge from ACME server
  let orderRequest = ACMEChallengeRequest(
    identifiers: domains.mapIt(ACMEChallengeIdentifier(`type`: "dns", value: it))
  )
  handleError("requestNewOrder"):
    let payload = await self.createSignedAcmeRequest(
      parseUri((await self.getDirectory()).newOrder),
      orderRequest,
      key,
      kid = Opt.some(kid),
    )
    let acmeResponse =
      await self.post(parseUri((await self.getDirectory()).newOrder), payload)
    let challengeResponseBody = acmeResponse.body.to(ACMEChallengeResponseBody)
    if challengeResponseBody.authorizations.len == 0:
      raise newException(ACMEError, "Authorizations field is empty")
    ACMEChallengeResponse(
      status: challengeResponseBody.status,
      authorizations: challengeResponseBody.authorizations,
      finalize: challengeResponseBody.finalize,
      order: acmeResponse.headers.keyOrError("location"),
    )

proc requestAuthorizations*(
  self: ACMEApi, authorizations: seq[Authorization], key: KeyPair, kid: Kid
): Future[ACMEAuthorizationsResponse] {.async: (raises: [ACMEError, CancelledError]).} =
  handleError("requestAuthorizations"):
    doAssert authorizations.len > 0
    let acmeResponse = await self.get(parseUri(authorizations[0]))
    acmeResponse.body.to(ACMEAuthorizationsResponse)

proc requestChallenge*(
  self: ACMEApi, domains: seq[Domain], key: KeyPair, kid: Kid
): Future[ACMEChallengeResponseWrapper] {.
  async: (raises: [ACMEError, CancelledError])
.} =
  let orderResponse = await self.requestNewOrder(domains, key, kid)
  if orderResponse.status != ACMEOrderStatus.PENDING and
      orderResponse.status != ACMEOrderStatus.READY:
    # ready is a valid status when renewing certs before expiry
    raise
      newException(ACMEError, "Invalid new order status: " & $orderResponse.status)

  let authorizationsResponse =
    await self.requestAuthorizations(orderResponse.authorizations, key, kid)
  if authorizationsResponse.challenges.len == 0:
    raise newException(ACMEError, "No challenges received")

  return ACMEChallengeResponseWrapper(
    finalize: orderResponse.finalize,
    order: orderResponse.order,
    dns01: authorizationsResponse.challenges.filterIt(
      it.`type` == ACMEChallengeType.DNS01
    )[0],
    # getting the first element is safe since we checked that authorizationsResponse.challenges.len != 0
  )

proc close*(self: ACMEApi) {.async: (raises: [CancelledError]).} =
  await self.session.closeWait()
proc requestCheck*(
  self: ACMEApi, checkURL: Uri, checkKind: ACMECheckKind, key: KeyPair, kid: Kid
): Future[ACMECheckResponse] {.async: (raises: [ACMEError, CancelledError]).} =
  handleError("requestCheck"):
    let acmeResponse = await self.get(checkURL)
    let retryAfter =
      try:
        parseInt(acmeResponse.headers.keyOrError("Retry-After")).seconds
      except ValueError:
        DefaultChalCompletedRetryTime

    case checkKind
    of ACMEOrderCheck:
      try:
        ACMECheckResponse(
          kind: checkKind,
          orderStatus: parseEnum[ACMEOrderStatus](acmeResponse.body["status"].getStr),
          retryAfter: retryAfter,
        )
      except ValueError:
        raise newException(
          ACMEError, "Invalid order status: " & acmeResponse.body["status"].getStr
        )
    of ACMEChallengeCheck:
      try:
        ACMECheckResponse(
          kind: checkKind,
          chalStatus:
            parseEnum[ACMEChallengeStatus](acmeResponse.body["status"].getStr),
          retryAfter: retryAfter,
        )
      except ValueError:
        raise newException(
          ACMEError, "Invalid order status: " & acmeResponse.body["status"].getStr
        )

proc sendChallengeCompleted*(
  self: ACMEApi, chalURL: Uri, key: KeyPair, kid: Kid
): Future[ACMECompletedResponse] {.async: (raises: [ACMEError, CancelledError]).} =
  handleError("sendChallengeCompleted"):
    let payload =
      await self.createSignedAcmeRequest(chalURL, %*{}, key, kid = Opt.some(kid))
    let acmeResponse = await self.post(chalURL, payload)
    acmeResponse.body.to(ACMECompletedResponse)

proc checkChallengeCompleted*(
  self: ACMEApi,
  checkURL: Uri,
  key: KeyPair,
  kid: Kid,
  retries: int = DefaultChalCompletedRetries,
): Future[bool] {.async: (raises: [ACMEError, CancelledError]).} =
  for i in 0 .. retries:
    let checkResponse =
      await self.requestCheck(checkURL, ACMEChallengeCheck, key, kid)
    case checkResponse.chalStatus
    of ACMEChallengeStatus.PENDING:
      await sleepAsync(checkResponse.retryAfter) # try again after some delay
    of ACMEChallengeStatus.VALID:
      return true
    else:
      raise newException(
        ACMEError,
        "Failed challenge completion: expected 'valid', got '" &
          $checkResponse.chalStatus & "'",
      )
  return false

proc completeChallenge*(
  self: ACMEApi,
  chalURL: Uri,
  key: KeyPair,
  kid: Kid,
  retries: int = DefaultChalCompletedRetries,
): Future[bool] {.async: (raises: [ACMEError, CancelledError]).} =
  let completedResponse = await self.sendChallengeCompleted(chalURL, key, kid)
  # check until acme server is done (poll validation)
  return await self.checkChallengeCompleted(chalURL, key, kid, retries = retries)

proc requestFinalize*(
  self: ACMEApi,
  domain: Domain,
  finalize: Uri,
  certKeyPair: KeyPair,
  key: KeyPair,
  kid: Kid,
): Future[ACMEFinalizeResponse] {.async: (raises: [ACMEError, CancelledError]).} =
  handleError("requestFinalize"):
    let payload = await self.createSignedAcmeRequest(
      finalize, %*{"csr": createCSR(domain, certKeyPair)}, key, kid = Opt.some(kid)
    )
    let acmeResponse = await self.post(finalize, payload)
    # server responds with updated order response
    acmeResponse.body.to(ACMEFinalizeResponse)

proc checkCertFinalized*(
  self: ACMEApi,
  order: Uri,
  key: KeyPair,
  kid: Kid,
  retries: int = DefaultChalCompletedRetries,
): Future[bool] {.async: (raises: [ACMEError, CancelledError]).} =
  for i in 0 .. retries:
    let checkResponse = await self.requestCheck(order, ACMEOrderCheck, key, kid)
    case checkResponse.orderStatus
    of ACMEOrderStatus.VALID:
      return true
    of ACMEOrderStatus.PROCESSING:
      await sleepAsync(checkResponse.retryAfter) # try again after some delay
    else:
      error "Failed certificate finalization",
        description = "expected 'valid', got '" & $checkResponse.orderStatus & "'"
      return false # do not try again

  return false

proc certificateFinalized*(
  self: ACMEApi,
  domain: Domain,
  finalize: Uri,
  order: Uri,
  certKeyPair: KeyPair,
  key: KeyPair,
  kid: Kid,
  retries: int = DefaultFinalizeRetries,
): Future[bool] {.async: (raises: [ACMEError, CancelledError]).} =
  let finalizeResponse =
    await self.requestFinalize(domain, finalize, certKeyPair, key, kid)
  # keep checking order until cert is valid (done)
  return await self.checkCertFinalized(order, key, kid, retries = retries)

proc requestGetOrder*(
  self: ACMEApi, order: Uri
): Future[ACMEOrderResponse] {.async: (raises: [ACMEError, CancelledError]).} =
  handleError("requestGetOrder"):
    let acmeResponse = await self.get(order)
    acmeResponse.body.to(ACMEOrderResponse)

proc downloadCertificate*(
  self: ACMEApi, order: Uri
): Future[ACMECertificateResponse] {.async: (raises: [ACMEError, CancelledError]).} =
  let orderResponse = await self.requestGetOrder(order)

  handleError("downloadCertificate"):
    let rawResponse = await HttpClientRequestRef
      .get(self.session, orderResponse.certificate)
      .get()
      .send()
    ACMECertificateResponse(
      rawCertificate: bytesToString(await rawResponse.getBodyBytes()),
|
||||
certificateExpiry: parse(orderResponse.expires, "yyyy-MM-dd'T'HH:mm:ss'Z'"),
|
||||
)
|
||||
|
||||
proc close*(self: ACMEApi) {.async: (raises: [CancelledError]).} =
|
||||
await self.session.closeWait()
|
||||
|
||||
else:
|
||||
{.hint: "autotls disabled. Use -d:libp2p_autotls_support".}
|
||||
|
||||
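The ACME API above is compiled only behind the feature flag named in the hint. A minimal sketch of the gate, assuming the module path libp2p/autotls/acme/api (inferred from the files touched in this diff, not stated by it):

when defined(libp2p_autotls_support):
  import libp2p/autotls/acme/api # assumed path; gate any code that needs the ACME API
else:
  {.error: "build with -d:libp2p_autotls_support to use autotls".}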
@@ -9,12 +9,9 @@

{.push raises: [].}

import uri
import chronos, results, chronicles, stew/byteutils

import ./api, ./utils
import chronicles
import ../../crypto/crypto
import ../../crypto/rsa
import ./api

export api

@@ -28,59 +25,74 @@ type ACMEClient* = ref object
logScope:
  topics = "libp2p acme client"

proc new*(
    T: typedesc[ACMEClient],
    rng: ref HmacDrbgContext = newRng(),
    api: ACMEApi = ACMEApi.new(acmeServerURL = parseUri(LetsEncryptURL)),
    key: Opt[KeyPair] = Opt.none(KeyPair),
    kid: Kid = Kid(""),
): T {.raises: [].} =
  let key = key.valueOr:
    KeyPair.random(PKScheme.RSA, rng[]).get()
  T(api: api, key: key, kid: kid)
when defined(libp2p_autotls_support):
  import uri
  import chronos, results, stew/byteutils
  import ../../crypto/rsa
  import ./utils

proc getOrInitKid*(
    self: ACMEClient
): Future[Kid] {.async: (raises: [ACMEError, CancelledError]).} =
  if self.kid.len == 0:
    let registerResponse = await self.api.requestRegister(self.key)
    self.kid = registerResponse.kid
  return self.kid
proc new*(
    T: typedesc[ACMEClient],
    rng: ref HmacDrbgContext = newRng(),
    api: ACMEApi = ACMEApi.new(acmeServerURL = parseUri(LetsEncryptURL)),
    key: Opt[KeyPair] = Opt.none(KeyPair),
    kid: Kid = Kid(""),
): T {.raises: [].} =
  let key = key.valueOr:
    KeyPair.random(PKScheme.RSA, rng[]).get()
  T(api: api, key: key, kid: kid)

proc genKeyAuthorization*(self: ACMEClient, token: string): KeyAuthorization =
  base64UrlEncode(@(sha256.digest((token & "." & thumbprint(self.key)).toBytes).data))
proc getOrInitKid*(
    self: ACMEClient
): Future[Kid] {.async: (raises: [ACMEError, CancelledError]).} =
  if self.kid.len == 0:
    let registerResponse = await self.api.requestRegister(self.key)
    self.kid = registerResponse.kid
  return self.kid

proc getChallenge*(
    self: ACMEClient, domains: seq[api.Domain]
): Future[ACMEChallengeResponseWrapper] {.async: (raises: [ACMEError, CancelledError]).} =
  await self.api.requestChallenge(domains, self.key, await self.getOrInitKid())
proc genKeyAuthorization*(self: ACMEClient, token: string): KeyAuthorization =
  base64UrlEncode(@(sha256.digest((token & "." & thumbprint(self.key)).toBytes).data))

proc getCertificate*(
    self: ACMEClient, domain: api.Domain, challenge: ACMEChallengeResponseWrapper
): Future[ACMECertificateResponse] {.async: (raises: [ACMEError, CancelledError]).} =
  let chalURL = parseUri(challenge.dns01.url)
  let orderURL = parseUri(challenge.order)
  let finalizeURL = parseUri(challenge.finalize)
  trace "sending challenge completed notification"
  discard
    await self.api.sendChallengeCompleted(chalURL, self.key, await self.getOrInitKid())
proc getChallenge*(
    self: ACMEClient, domains: seq[api.Domain]
): Future[ACMEChallengeResponseWrapper] {.
    async: (raises: [ACMEError, CancelledError])
.} =
  await self.api.requestChallenge(domains, self.key, await self.getOrInitKid())

  trace "checking for completed challenge"
  let completed =
    await self.api.checkChallengeCompleted(chalURL, self.key, await self.getOrInitKid())
  if not completed:
    raise
      newException(ACMEError, "Failed to signal ACME server about challenge completion")
proc getCertificate*(
    self: ACMEClient,
    domain: api.Domain,
    certKeyPair: KeyPair,
    challenge: ACMEChallengeResponseWrapper,
): Future[ACMECertificateResponse] {.async: (raises: [ACMEError, CancelledError]).} =
  let chalURL = parseUri(challenge.dns01.url)
  let orderURL = parseUri(challenge.order)
  let finalizeURL = parseUri(challenge.finalize)
  trace "Sending challenge completed notification"
  discard await self.api.sendChallengeCompleted(
    chalURL, self.key, await self.getOrInitKid()
  )

  trace "waiting for certificate to be finalized"
  let finalized = await self.api.certificateFinalized(
    domain, finalizeURL, orderURL, self.key, await self.getOrInitKid()
  )
  if not finalized:
    raise newException(ACMEError, "Failed to finalize certificate for domain " & domain)
  trace "Checking for completed challenge"
  let completed = await self.api.checkChallengeCompleted(
    chalURL, self.key, await self.getOrInitKid()
  )
  if not completed:
    raise newException(
      ACMEError, "Failed to signal ACME server about challenge completion"
    )

  trace "downloading certificate"
  await self.api.downloadCertificate(orderURL)
  trace "Waiting for certificate to be finalized"
  let finalized = await self.api.certificateFinalized(
    domain, finalizeURL, orderURL, certKeyPair, self.key, await self.getOrInitKid()
  )
  if not finalized:
    raise
      newException(ACMEError, "Failed to finalize certificate for domain " & domain)

proc close*(self: ACMEClient) {.async: (raises: [CancelledError]).} =
  await self.api.close()
  trace "Downloading certificate"
  await self.api.downloadCertificate(orderURL)

proc close*(self: ACMEClient) {.async: (raises: [CancelledError]).} =
  await self.api.close()
@@ -21,19 +21,20 @@ proc new*(
    acmeServerURL: parseUri(LetsEncryptURL),
  )

method requestNonce*(
    self: MockACMEApi
): Future[Nonce] {.async: (raises: [ACMEError, CancelledError]).} =
  return $self.acmeServerURL & "/acme/1234"
when defined(libp2p_autotls_support):
  method requestNonce*(
      self: MockACMEApi
  ): Future[Nonce] {.async: (raises: [ACMEError, CancelledError]).} =
    return $self.acmeServerURL & "/acme/1234"

method post*(
    self: MockACMEApi, uri: Uri, payload: string
): Future[HTTPResponse] {.async: (raises: [ACMEError, HttpError, CancelledError]).} =
  result = self.mockedResponses[0]
  self.mockedResponses.delete(0)
method post*(
    self: MockACMEApi, uri: Uri, payload: string
): Future[HTTPResponse] {.async: (raises: [ACMEError, HttpError, CancelledError]).} =
  result = self.mockedResponses[0]
  self.mockedResponses.delete(0)

method get*(
    self: MockACMEApi, uri: Uri
): Future[HTTPResponse] {.async: (raises: [ACMEError, HttpError, CancelledError]).} =
  result = self.mockedResponses[0]
  self.mockedResponses.delete(0)
method get*(
    self: MockACMEApi, uri: Uri
): Future[HTTPResponse] {.async: (raises: [ACMEError, HttpError, CancelledError]).} =
  result = self.mockedResponses[0]
  self.mockedResponses.delete(0)
@@ -1,67 +1,73 @@
import base64, strutils, chronos/apps/http/httpclient, json
import ../../errors
import ../../transports/tls/certificate_ffi
import ../../transports/tls/certificate
import ../../crypto/crypto
import ../../crypto/rsa

type ACMEError* = object of LPError

proc keyOrError*(table: HttpTable, key: string): string {.raises: [ValueError].} =
  if not table.contains(key):
    raise newException(ValueError, "key " & key & " not present in headers")
  table.getString(key)
when defined(libp2p_autotls_support):
  import base64, strutils, chronos/apps/http/httpclient, json
  import ../../transports/tls/certificate_ffi
  import ../../transports/tls/certificate
  import ../../crypto/crypto
  import ../../crypto/rsa

proc base64UrlEncode*(data: seq[byte]): string =
  ## Encodes data using base64url (RFC 4648 §5) — no padding, URL-safe
  var encoded = base64.encode(data, safe = true)
  encoded.removeSuffix("=")
  encoded.removeSuffix("=")
  return encoded
proc keyOrError*(table: HttpTable, key: string): string {.raises: [ValueError].} =
  if not table.contains(key):
    raise newException(ValueError, "key " & key & " not present in headers")
  table.getString(key)

proc thumbprint*(key: KeyPair): string =
  doAssert key.seckey.scheme == PKScheme.RSA, "unsupported keytype"
  let pubkey = key.pubkey.rsakey
  let nArray = @(getArray(pubkey.buffer, pubkey.key.n, pubkey.key.nlen))
  let eArray = @(getArray(pubkey.buffer, pubkey.key.e, pubkey.key.elen))
proc base64UrlEncode*(data: seq[byte]): string =
  ## Encodes data using base64url (RFC 4648 §5) — no padding, URL-safe
  var encoded = base64.encode(data, safe = true)
  encoded.removeSuffix("=")
  encoded.removeSuffix("=")
  return encoded

  let n = base64UrlEncode(nArray)
  let e = base64UrlEncode(eArray)
  let keyJson = %*{"e": e, "kty": "RSA", "n": n}
  let digest = sha256.digest($keyJson)
  return base64UrlEncode(@(digest.data))
proc thumbprint*(key: KeyPair): string =
  doAssert key.seckey.scheme == PKScheme.RSA, "unsupported keytype"
  let pubkey = key.pubkey.rsakey
  let nArray = @(getArray(pubkey.buffer, pubkey.key.n, pubkey.key.nlen))
  let eArray = @(getArray(pubkey.buffer, pubkey.key.e, pubkey.key.elen))

proc getResponseBody*(
    response: HttpClientResponseRef
): Future[JsonNode] {.async: (raises: [ACMEError, CancelledError]).} =
  try:
    let bodyBytes = await response.getBodyBytes()
    if bodyBytes.len > 0:
      return bytesToString(bodyBytes).parseJson()
    return %*{} # empty body
  except CancelledError as exc:
    raise exc
  except CatchableError as exc:
    raise
      newException(ACMEError, "Unexpected error occurred while getting body bytes", exc)
  except Exception as exc: # this is required for nim 1.6
    raise
      newException(ACMEError, "Unexpected error occurred while getting body bytes", exc)
  let n = base64UrlEncode(nArray)
  let e = base64UrlEncode(eArray)
  let keyJson = %*{"e": e, "kty": "RSA", "n": n}
  let digest = sha256.digest($keyJson)
  return base64UrlEncode(@(digest.data))

proc createCSR*(domain: string): string {.raises: [ACMEError].} =
  var certKey: cert_key_t
  var certCtx: cert_context_t
  var derCSR: ptr cert_buffer = nil
proc getResponseBody*(
    response: HttpClientResponseRef
): Future[JsonNode] {.async: (raises: [ACMEError, CancelledError]).} =
  try:
    let bodyBytes = await response.getBodyBytes()
    if bodyBytes.len > 0:
      return bytesToString(bodyBytes).parseJson()
    return %*{} # empty body
  except CancelledError as exc:
    raise exc
  except CatchableError as exc:
    raise newException(
      ACMEError, "Unexpected error occurred while getting body bytes", exc
    )
  except Exception as exc: # this is required for nim 1.6
    raise newException(
      ACMEError, "Unexpected error occurred while getting body bytes", exc
    )

  let personalizationStr = "libp2p_autotls"
  if cert_init_drbg(
    personalizationStr.cstring, personalizationStr.len.csize_t, certCtx.addr
  ) != CERT_SUCCESS:
    raise newException(ACMEError, "Failed to initialize certCtx")
  if cert_generate_key(certCtx, certKey.addr) != CERT_SUCCESS:
    raise newException(ACMEError, "Failed to generate cert key")
proc createCSR*(
    domain: string, certKeyPair: KeyPair
): string {.raises: [ACMEError].} =
  var certKey: cert_key_t
  var certCtx: cert_context_t
  var derCSR: ptr cert_buffer = nil

  if cert_signing_req(domain.cstring, certKey, derCSR.addr) != CERT_SUCCESS:
    raise newException(ACMEError, "Failed to create CSR")
  # convert KeyPair to cert_key_t
  let rawSeckey: seq[byte] = certKeyPair.seckey.getRawBytes.valueOr:
    raise newException(ACMEError, "Failed to get seckey raw bytes (DER)")
  let seckeyBuffer = rawSeckey.toCertBuffer()
  if cert_new_key_t(seckeyBuffer.unsafeAddr, certKey.addr) != CERT_SUCCESS:
    raise newException(ACMEError, "Failed to convert key pair to cert_key_t")

  base64.encode(derCSR.toSeq, safe = true)
  # create CSR
  if cert_signing_req(domain.cstring, certKey, derCSR.addr) != CERT_SUCCESS:
    raise newException(ACMEError, "Failed to create CSR")

  base64.encode(derCSR.toSeq, safe = true)
libp2p/autotls/mockservice.nim
@@ -0,0 +1,33 @@
when defined(libp2p_autotls_support):
  import ./service, ./acme/client, ../peeridauth/client

  import ../crypto/crypto, ../crypto/rsa, websock/websock

  type MockAutotlsService* = ref object of AutotlsService
    mockedCert*: TLSCertificate
    mockedKey*: TLSPrivateKey

  proc new*(
      T: typedesc[MockAutotlsService],
      rng: ref HmacDrbgContext = newRng(),
      config: AutotlsConfig = AutotlsConfig.new(),
  ): T =
    T(
      acmeClient:
        ACMEClient.new(api = ACMEApi.new(acmeServerURL = config.acmeServerURL)),
      brokerClient: PeerIDAuthClient.new(),
      bearer: Opt.none(BearerToken),
      cert: Opt.none(AutotlsCert),
      certReady: newAsyncEvent(),
      running: newAsyncEvent(),
      config: config,
      rng: rng,
    )

  method getCertWhenReady*(
      self: MockAutotlsService
  ): Future[AutotlsCert] {.async: (raises: [AutoTLSError, CancelledError]).} =
    AutotlsCert.new(self.mockedCert, self.mockedKey, Moment.now)

  method setup*(self: MockAutotlsService) {.base, async.} =
    self.running.fire()
@@ -10,19 +10,17 @@
{.push raises: [].}
{.push public.}

import net, results, json, sequtils

import chronos/apps/http/httpclient, chronos, chronicles, bearssl/rand
import chronos, chronicles, net, results
import chronos/apps/http/httpclient, bearssl/rand

import
  ./acme/client,
  ./utils,
  ../crypto/crypto,
  ../nameresolving/dnsresolver,
  ../nameresolving/nameresolver,
  ../peeridauth/client,
  ../peerinfo,
  ../switch,
  ../utils/heartbeat,
  ../peerinfo,
  ../wire

logScope:
@@ -40,6 +38,9 @@ const
  DefaultRenewCheckTime* = 1.hours
  DefaultRenewBufferTime = 1.hours

  DefaultIssueRetries = 3
  DefaultIssueRetryTime = 1.seconds

  AutoTLSBroker* = "registration.libp2p.direct"
  AutoTLSDNSServer* = "libp2p.direct"
  HttpOk* = 200
@@ -53,174 +54,242 @@ type SigParam = object

type AutotlsCert* = ref object
  cert*: TLSCertificate
  privkey*: TLSPrivateKey
  expiry*: Moment

type AutotlsConfig* = ref object
  acmeServerURL*: Uri
  dnsResolver*: DnsResolver
  nameResolver*: NameResolver
  ipAddress: Opt[IpAddress]
  renewCheckTime*: Duration
  renewBufferTime*: Duration
  issueRetries*: int
  issueRetryTime*: Duration

type AutotlsService* = ref object of Service
  acmeClient: ACMEClient
  acmeClient*: ACMEClient
  brokerClient*: PeerIDAuthClient
  bearer*: Opt[BearerToken]
  brokerClient: PeerIDAuthClient
  cert*: Opt[AutotlsCert]
  certReady*: AsyncEvent
  config: AutotlsConfig
  running*: AsyncEvent
  config*: AutotlsConfig
  managerFut: Future[void]
  peerInfo: PeerInfo
  rng: ref HmacDrbgContext
  rng*: ref HmacDrbgContext

proc new*(T: typedesc[AutotlsCert], cert: TLSCertificate, expiry: Moment): T =
  T(cert: cert, expiry: expiry)
when defined(libp2p_autotls_support):
  import json, sequtils, bearssl/pem

proc getCertWhenReady*(
    self: AutotlsService
): Future[TLSCertificate] {.async: (raises: [AutoTLSError, CancelledError]).} =
  await self.certReady.wait()
  return self.cert.get.cert
  import
    ../crypto/rsa,
    ../utils/heartbeat,
    ../transports/transport,
    ../utils/ipaddr,
    ../transports/tcptransport,
    ../nameresolving/dnsresolver

proc new*(
    T: typedesc[AutotlsConfig],
    ipAddress: Opt[IpAddress] = NoneIp,
    nameServers: seq[TransportAddress] = DefaultDnsServers,
    acmeServerURL: Uri = parseUri(LetsEncryptURL),
    renewCheckTime: Duration = DefaultRenewCheckTime,
    renewBufferTime: Duration = DefaultRenewBufferTime,
): T =
  T(
    dnsResolver: DnsResolver.new(nameServers),
    acmeServerURL: acmeServerURL,
    ipAddress: ipAddress,
    renewCheckTime: renewCheckTime,
    renewBufferTime: renewBufferTime,
  )
proc new*(
    T: typedesc[AutotlsCert],
    cert: TLSCertificate,
    privkey: TLSPrivateKey,
    expiry: Moment,
): T =
  T(cert: cert, privkey: privkey, expiry: expiry)

proc new*(
    T: typedesc[AutotlsService],
    rng: ref HmacDrbgContext = newRng(),
    config: AutotlsConfig = AutotlsConfig.new(),
): T =
  T(
    acmeClient: ACMEClient.new(api = ACMEApi.new(acmeServerURL = config.acmeServerURL)),
    brokerClient: PeerIDAuthClient.new(),
    bearer: Opt.none(BearerToken),
    cert: Opt.none(AutotlsCert),
    certReady: newAsyncEvent(),
    config: config,
    managerFut: nil,
    peerInfo: nil,
    rng: rng,
  )
method getCertWhenReady*(
    self: AutotlsService
): Future[AutotlsCert] {.base, async: (raises: [AutoTLSError, CancelledError]).} =
  await self.certReady.wait()
  return self.cert.get

method setup*(
    self: AutotlsService, switch: Switch
): Future[bool] {.async: (raises: [CancelledError]).} =
  trace "Setting up AutotlsService"
  let hasBeenSetup = await procCall Service(self).setup(switch)
  if hasBeenSetup:
    self.peerInfo = switch.peerInfo
    if self.config.ipAddress.isNone():
      try:
        self.config.ipAddress = Opt.some(getPublicIPAddress())
      except AutoTLSError as exc:
        error "Failed to get public IP address", err = exc.msg
        return false
    self.managerFut = self.run(switch)
  return hasBeenSetup

method issueCertificate(
    self: AutotlsService
) {.base, async: (raises: [AutoTLSError, ACMEError, PeerIDAuthError, CancelledError]).} =
  trace "Issuing certificate"

  assert not self.peerInfo.isNil(), "Cannot issue new certificate: peerInfo not set"

  # generate autotls domain string: "*.{peerID}.libp2p.direct"
  let baseDomain =
    api.Domain(encodePeerId(self.peerInfo.peerId) & "." & AutoTLSDNSServer)
  let domain = api.Domain("*." & baseDomain)

  let acmeClient = self.acmeClient

  trace "Requesting ACME challenge"
  let dns01Challenge = await acmeClient.getChallenge(@[domain])
  let keyAuth = acmeClient.genKeyAuthorization(dns01Challenge.dns01.token)
  let strMultiaddresses: seq[string] = self.peerInfo.addrs.mapIt($it)
  let payload = %*{"value": keyAuth, "addresses": strMultiaddresses}
  let registrationURL = parseUri("https://" & AutoTLSBroker & "/v1/_acme-challenge")

  trace "Sending challenge to AutoTLS broker"
  let (bearer, response) =
    await self.brokerClient.send(registrationURL, self.peerInfo, payload, self.bearer)
  if self.bearer.isNone():
    # save bearer token for future
    self.bearer = Opt.some(bearer)
  if response.status != HttpOk:
    raise newException(
      AutoTLSError, "Failed to authenticate with AutoTLS Broker at " & AutoTLSBroker
proc new*(
    T: typedesc[AutotlsConfig],
    ipAddress: Opt[IpAddress] = NoneIp,
    nameServers: seq[TransportAddress] = DefaultDnsServers,
    acmeServerURL: Uri = parseUri(LetsEncryptURL),
    renewCheckTime: Duration = DefaultRenewCheckTime,
    renewBufferTime: Duration = DefaultRenewBufferTime,
    issueRetries: int = DefaultIssueRetries,
    issueRetryTime: Duration = DefaultIssueRetryTime,
): T =
  T(
    nameResolver: DnsResolver.new(nameServers),
    acmeServerURL: acmeServerURL,
    ipAddress: ipAddress,
    renewCheckTime: renewCheckTime,
    renewBufferTime: renewBufferTime,
    issueRetries: issueRetries,
    issueRetryTime: issueRetryTime,
  )

  debug "Waiting for DNS record to be set"
  let dnsSet = await checkDNSRecords(
    self.config.dnsResolver, self.config.ipAddress.get(), baseDomain, keyAuth
  )
  if not dnsSet:
    raise newException(AutoTLSError, "DNS records not set")
proc new*(
    T: typedesc[AutotlsService],
    rng: ref HmacDrbgContext = newRng(),
    config: AutotlsConfig = AutotlsConfig.new(),
): T =
  T(
    acmeClient:
      ACMEClient.new(api = ACMEApi.new(acmeServerURL = config.acmeServerURL)),
    brokerClient: PeerIDAuthClient.new(),
    bearer: Opt.none(BearerToken),
    cert: Opt.none(AutotlsCert),
    certReady: newAsyncEvent(),
    running: newAsyncEvent(),
    config: config,
    managerFut: nil,
    peerInfo: nil,
    rng: rng,
  )

  debug "Notifying challenge completion to ACME and downloading cert"
  let certResponse = await acmeClient.getCertificate(domain, dns01Challenge)
method setup*(
    self: AutotlsService, switch: Switch
): Future[bool] {.async: (raises: [CancelledError]).} =
  trace "Setting up AutotlsService"
  let hasBeenSetup = await procCall Service(self).setup(switch)
  if hasBeenSetup:
    if self.config.ipAddress.isNone():
      try:
        self.config.ipAddress = Opt.some(getPublicIPAddress())
      except ValueError as exc:
        error "Failed to get public IP address", err = exc.msg
        return false
      except OSError as exc:
        error "Failed to get public IP address", err = exc.msg
        return false
    self.managerFut = self.run(switch)
  return hasBeenSetup

  debug "Installing certificate"
  let newCert =
    try:
      AutotlsCert.new(
        TLSCertificate.init(certResponse.rawCertificate),
        asMoment(certResponse.certificateExpiry),
      )
    except TLSStreamProtocolError:
      raise newException(AutoTLSError, "Could not parse downloaded certificates")
  self.cert = Opt.some(newCert)
  self.certReady.fire()
  debug "Certificate installed"
method issueCertificate(
    self: AutotlsService
): Future[bool] {.
    base, async: (raises: [AutoTLSError, ACMEError, PeerIDAuthError, CancelledError])
.} =
  trace "Issuing certificate"

method run*(
    self: AutotlsService, switch: Switch
) {.async: (raises: [CancelledError]).} =
  heartbeat "Certificate Management", self.config.renewCheckTime:
    if self.cert.isNone():
  if self.peerInfo.isNil():
    error "Cannot issue new certificate: peerInfo not set"
    return false

  # generate autotls domain string: "*.{peerID}.libp2p.direct"
  let baseDomain =
    api.Domain(encodePeerId(self.peerInfo.peerId) & "." & AutoTLSDNSServer)
  let domain = api.Domain("*." & baseDomain)

  let acmeClient = self.acmeClient

  trace "Requesting ACME challenge"
  let dns01Challenge = await acmeClient.getChallenge(@[domain])
  trace "Generating key authorization"
  let keyAuth = acmeClient.genKeyAuthorization(dns01Challenge.dns01.token)

  let addrs = await self.peerInfo.expandAddrs()
  if addrs.len == 0:
    error "Unable to authenticate with broker: no addresses"
    return false

  let strMultiaddresses: seq[string] = addrs.mapIt($it)
  let payload = %*{"value": keyAuth, "addresses": strMultiaddresses}
  let registrationURL = parseUri("https://" & AutoTLSBroker & "/v1/_acme-challenge")

  trace "Sending challenge to AutoTLS broker"
  let (bearer, response) =
    await self.brokerClient.send(registrationURL, self.peerInfo, payload, self.bearer)
  if self.bearer.isNone():
    # save bearer token for future
    self.bearer = Opt.some(bearer)
  if response.status != HttpOk:
    error "Failed to authenticate with AutoTLS Broker at " & AutoTLSBroker
    debug "Broker message",
      body = bytesToString(response.body), peerinfo = self.peerInfo
    return false

  let dashedIpAddr = ($self.config.ipAddress.get()).replace(".", "-")
  let acmeChalDomain = api.Domain("_acme-challenge." & baseDomain)
  let ip4Domain = api.Domain(dashedIpAddr & "." & baseDomain)
  debug "Waiting for DNS record to be set", ip = ip4Domain, acme = acmeChalDomain
  let dnsSet = await checkDNSRecords(
    self.config.nameResolver, self.config.ipAddress.get(), baseDomain, keyAuth
  )
  if not dnsSet:
    error "DNS records not set"
    return false

  trace "Notifying challenge completion to ACME and downloading cert"
  let certKeyPair = KeyPair.random(PKScheme.RSA, self.rng[]).get()

  let certificate =
    await acmeClient.getCertificate(domain, certKeyPair, dns01Challenge)

  let derPrivKey = certKeyPair.seckey.rsakey.getBytes.valueOr:
    raise newException(AutoTLSError, "Unable to get TLS private key")
  let pemPrivKey: string = derPrivKey.pemEncode("PRIVATE KEY")
  debug "autotls cert", pemPrivKey = pemPrivKey, cert = certificate.rawCertificate

  trace "Installing certificate"
  let newCert =
    try:
      await self.issueCertificate()
      AutotlsCert.new(
        TLSCertificate.init(certificate.rawCertificate),
        TLSPrivateKey.init(pemPrivKey),
        asMoment(certificate.certificateExpiry),
      )
    except TLSStreamProtocolError:
      error "Could not parse downloaded certificates"
      return false
  self.cert = Opt.some(newCert)
  self.certReady.fire()
  trace "Certificate installed"
  true

proc hasTcpStarted(switch: Switch): bool =
  switch.transports.filterIt(it of TcpTransport and it.running).len == 0

proc tryIssueCertificate(self: AutotlsService) {.async: (raises: [CancelledError]).} =
  for _ in 0 ..< self.config.issueRetries:
    try:
      if await self.issueCertificate():
        return
    except CancelledError as exc:
      raise exc
    except CatchableError as exc:
      error "Failed to issue certificate", err = exc.msg
      break
    await sleepAsync(self.config.issueRetryTime)
  error "Failed to issue certificate"

    # AutotlsService will renew the cert 1h before it expires
    let cert = self.cert.get
    let waitTime = cert.expiry - Moment.now - self.config.renewBufferTime
    if waitTime <= self.config.renewBufferTime:
      try:
        await self.issueCertificate()
      except CancelledError as exc:
        raise exc
      except CatchableError as exc:
        error "Failed to renew certificate", err = exc.msg
        break
method run*(
    self: AutotlsService, switch: Switch
) {.async: (raises: [CancelledError]).} =
  trace "Starting Autotls management"
  self.running.fire()
  self.peerInfo = switch.peerInfo

method stop*(
    self: AutotlsService, switch: Switch
): Future[bool] {.async: (raises: [CancelledError]).} =
  let hasBeenStopped = await procCall Service(self).stop(switch)
  if hasBeenStopped:
    if not self.acmeClient.isNil():
      await self.acmeClient.close()
    if not self.brokerClient.isNil():
      await self.brokerClient.close()
    if not self.managerFut.isNil():
      await self.managerFut.cancelAndWait()
      self.managerFut = nil
  return hasBeenStopped
  # ensure that there's at least one TcpTransport running
  # for communicating with autotls broker
  if switch.hasTcpStarted():
    error "Could not find a running TcpTransport in switch"
    return

  heartbeat "Certificate Management", self.config.renewCheckTime:
    if self.cert.isNone():
      await self.tryIssueCertificate()

    # AutotlsService will renew the cert 1h before it expires
    let cert = self.cert.get
    let waitTime = cert.expiry - Moment.now - self.config.renewBufferTime
    if waitTime <= self.config.renewBufferTime:
      await self.tryIssueCertificate()

method stop*(
    self: AutotlsService, switch: Switch
): Future[bool] {.async: (raises: [CancelledError]).} =
  let hasBeenStopped = await procCall Service(self).stop(switch)
  if hasBeenStopped:
    if not self.acmeClient.isNil():
      await self.acmeClient.close()
    if not self.brokerClient.isNil():
      await self.brokerClient.close()
    if not self.managerFut.isNil():
      await self.managerFut.cancelAndWait()
      self.managerFut = nil
  return hasBeenStopped
@@ -6,104 +6,77 @@
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.

{.push raises: [].}
{.push public.}

import net, strutils
from times import DateTime, toTime, toUnix
import chronos, chronicles
import ../errors

import chronos, stew/base36, chronicles

import
  ./acme/client,
  ../errors,
  ../peerid,
  ../multihash,
  ../cid,
  ../multicodec,
  ../nameresolving/dnsresolver
logScope:
  topics = "libp2p utils"

const
  DefaultDnsRetries = 10
  DefaultDnsRetries = 3
  DefaultDnsRetryTime = 1.seconds

type AutoTLSError* = object of LPError

proc checkedGetPrimaryIPAddr*(): IpAddress {.raises: [AutoTLSError].} =
  # This is so that we don't need to catch Exceptions directly
  # since we support 1.6.16 and getPrimaryIPAddr before nim 2 didn't have explicit .raises. pragmas
  try:
    return getPrimaryIPAddr()
  except Exception as exc:
    raise newException(AutoTLSError, "Error while getting primary IP address", exc)
when defined(libp2p_autotls_support):
  import strutils
  from times import DateTime, toTime, toUnix
  import stew/base36
  import
    ../peerid,
    ../multihash,
    ../cid,
    ../multicodec,
    ../nameresolving/nameresolver,
    ./acme/client

proc isIPv4*(ip: IpAddress): bool =
  ip.family == IpAddressFamily.IPv4
proc asMoment*(dt: DateTime): Moment =
  let unixTime: int64 = dt.toTime.toUnix
  return Moment.init(unixTime, Second)

proc isPublic*(ip: IpAddress): bool {.raises: [AutoTLSError].} =
  let ip = $ip
  try:
    not (
      ip.startsWith("10.") or
      (ip.startsWith("172.") and parseInt(ip.split(".")[1]) in 16 .. 31) or
      ip.startsWith("192.168.") or ip.startsWith("127.") or ip.startsWith("169.254.")
    )
  except ValueError as exc:
    raise newException(AutoTLSError, "Failed to parse IP address", exc)
proc encodePeerId*(peerId: PeerId): string {.raises: [AutoTLSError].} =
  var mh: MultiHash
  let decodeResult = MultiHash.decode(peerId.data, mh)
  if decodeResult.isErr() or decodeResult.get() == -1:
    raise
      newException(AutoTLSError, "Failed to decode PeerId: invalid multihash format")

proc getPublicIPAddress*(): IpAddress {.raises: [AutoTLSError].} =
  let ip = checkedGetPrimaryIPAddr()
  if not ip.isIPv4():
    raise newException(AutoTLSError, "Host does not have an IPv4 address")
  if not ip.isPublic():
    raise newException(AutoTLSError, "Host does not have a public IPv4 address")
  return ip
  let cidResult = Cid.init(CIDv1, multiCodec("libp2p-key"), mh)
  if cidResult.isErr():
    raise newException(AutoTLSError, "Failed to initialize CID from multihash")

proc asMoment*(dt: DateTime): Moment =
  let unixTime: int64 = dt.toTime.toUnix
  return Moment.init(unixTime, Second)
  return Base36.encode(cidResult.get().data.buffer)

proc encodePeerId*(peerId: PeerId): string {.raises: [AutoTLSError].} =
  var mh: MultiHash
  let decodeResult = MultiHash.decode(peerId.data, mh)
  if decodeResult.isErr() or decodeResult.get() == -1:
    raise
      newException(AutoTLSError, "Failed to decode PeerId: invalid multihash format")
proc checkDNSRecords*(
    nameResolver: NameResolver,
    ipAddress: IpAddress,
    baseDomain: api.Domain,
    keyAuth: KeyAuthorization,
    retries: int = DefaultDnsRetries,
): Future[bool] {.async: (raises: [AutoTLSError, CancelledError]).} =
  # if my ip address is 100.10.10.3 then the ip4Domain will be:
  # 100-10-10-3.{peerIdBase36}.libp2p.direct
  # and acme challenge TXT domain will be:
  # _acme-challenge.{peerIdBase36}.libp2p.direct
  let dashedIpAddr = ($ipAddress).replace(".", "-")
  let acmeChalDomain = api.Domain("_acme-challenge." & baseDomain)
  let ip4Domain = api.Domain(dashedIpAddr & "." & baseDomain)

  let cidResult = Cid.init(CIDv1, multiCodec("libp2p-key"), mh)
  if cidResult.isErr():
    raise newException(AutoTLSError, "Failed to initialize CID from multihash")
  var txt: seq[string]
  var ip4: seq[TransportAddress]
  for _ in 0 .. retries:
    txt = await nameResolver.resolveTxt(acmeChalDomain)
    try:
      ip4 = await nameResolver.resolveIp(ip4Domain, 0.Port)
    except CancelledError as exc:
      raise exc
    except CatchableError as exc:
      error "Failed to resolve IP", description = exc.msg # retry
    if txt.len > 0 and txt[0] == keyAuth and ip4.len > 0:
      return true
    await sleepAsync(DefaultDnsRetryTime)

  return Base36.encode(cidResult.get().data.buffer)

proc checkDNSRecords*(
    dnsResolver: DnsResolver,
    ipAddress: IpAddress,
    baseDomain: api.Domain,
    keyAuth: KeyAuthorization,
    retries: int = DefaultDnsRetries,
): Future[bool] {.async: (raises: [AutoTLSError, CancelledError]).} =
  # if my ip address is 100.10.10.3 then the ip4Domain will be:
  # 100-10-10-3.{peerIdBase36}.libp2p.direct
  # and acme challenge TXT domain will be:
  # _acme-challenge.{peerIdBase36}.libp2p.direct
  let dashedIpAddr = ($ipAddress).replace(".", "-")
  let acmeChalDomain = api.Domain("_acme-challenge." & baseDomain)
  let ip4Domain = api.Domain(dashedIpAddr & "." & baseDomain)

  var txt: seq[string]
  var ip4: seq[TransportAddress]
  for _ in 0 .. retries:
    txt = await dnsResolver.resolveTxt(acmeChalDomain)
    try:
      ip4 = await dnsResolver.resolveIp(ip4Domain, 0.Port)
    except CancelledError as exc:
      raise exc
    except CatchableError as exc:
      error "Failed to resolve IP", description = exc.msg # retry
    if txt.len > 0 and txt[0] == keyAuth and ip4.len > 0:
      return true
    await sleepAsync(DefaultDnsRetryTime)

  return false
  return false
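A short worked example of the record names that checkDNSRecords above polls for; the concrete values are illustrative only and not taken from this diff:

import strutils

let ip = "100.10.10.3"
let base = "k51example.libp2p.direct"             # hypothetical {peerIdBase36}.libp2p.direct domain
let ip4Domain = ip.replace(".", "-") & "." & base # "100-10-10-3.k51example.libp2p.direct" (A record)
let acmeChalDomain = "_acme-challenge." & base    # TXT record expected to contain keyAuth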
@@ -15,7 +15,7 @@ runnableExamples:

{.push raises: [].}

import options, tables, chronos, chronicles, sequtils, uri
import options, tables, chronos, chronicles, sequtils
import
  switch,
  peerid,
@@ -26,7 +26,8 @@ import
  transports/[transport, tcptransport, wstransport, memorytransport],
  muxers/[muxer, mplex/mplex, yamux/yamux],
  protocols/[identify, secure/secure, secure/noise, rendezvous],
  protocols/connectivity/[autonat/server, relay/relay, relay/client, relay/rtransport],
  protocols/connectivity/
    [autonat/server, autonatv2/server, relay/relay, relay/client, relay/rtransport],
  connmanager,
  upgrademngrs/muxedupgrade,
  observedaddrmanager,
@@ -43,9 +44,16 @@ export
const MemoryAutoAddress* = memorytransport.MemoryAutoAddress

type
  TransportProvider* {.public.} = proc(
    upgr: Upgrade, privateKey: PrivateKey, autotls: AutotlsService
  ): Transport {.gcsafe, raises: [].}
  TransportProvider* {.deprecated: "Use TransportBuilder instead".} =
    proc(upgr: Upgrade, privateKey: PrivateKey): Transport {.gcsafe, raises: [].}

  TransportBuilder* {.public.} =
    proc(config: TransportConfig): Transport {.gcsafe, raises: [].}

  TransportConfig* = ref object
    upgr*: Upgrade
    privateKey*: PrivateKey
    autotls*: AutotlsService

  SecureProtocol* {.pure.} = enum
    Noise
@@ -55,7 +63,7 @@ type
    addresses: seq[MultiAddress]
    secureManagers: seq[SecureProtocol]
    muxers: seq[MuxerProvider]
    transports: seq[TransportProvider]
    transports: seq[TransportBuilder]
    rng: ref HmacDrbgContext
    maxConnections: int
    maxIn: int
@@ -67,6 +75,8 @@ type
    nameResolver: NameResolver
    peerStoreCapacity: Opt[int]
    autonat: bool
    autonatV2: bool
    autonatV2Config: AutonatV2Config
    autotls: AutotlsService
    circuitRelay: Relay
    rdv: RendezVous
@@ -152,28 +162,42 @@ proc withNoise*(b: SwitchBuilder): SwitchBuilder {.public.} =
  b

proc withTransport*(
    b: SwitchBuilder, prov: TransportProvider
    b: SwitchBuilder, prov: TransportBuilder
): SwitchBuilder {.public.} =
  ## Use a custom transport
  runnableExamples:
    let switch = SwitchBuilder
      .new()
      .withTransport(
        proc(
            upgr: Upgrade, privateKey: PrivateKey, autotls: AutotlsService
        ): Transport =
          TcpTransport.new(flags, upgr)
        proc(config: TransportConfig): Transport =
          TcpTransport.new(flags, config.upgr)
      )
      .build()
  b.transports.add(prov)
  b

proc withTransport*(
    b: SwitchBuilder, prov: TransportProvider
): SwitchBuilder {.deprecated: "Use TransportBuilder instead".} =
  ## Use a custom transport
  runnableExamples:
    let switch = SwitchBuilder
      .new()
      .withTransport(
        proc(upgr: Upgrade, privateKey: PrivateKey): Transport =
          TcpTransport.new(flags, upgr)
      )
      .build()
  let tBuilder: TransportBuilder = proc(config: TransportConfig): Transport =
    prov(config.upgr, config.privateKey)
  b.withTransport(tBuilder)

proc withTcpTransport*(
    b: SwitchBuilder, flags: set[ServerFlags] = {}
): SwitchBuilder {.public.} =
  b.withTransport(
    proc(upgr: Upgrade, privateKey: PrivateKey, autotls: AutotlsService): Transport =
      TcpTransport.new(flags, upgr)
    proc(config: TransportConfig): Transport =
      TcpTransport.new(flags, config.upgr)
  )

proc withWsTransport*(
@@ -184,8 +208,10 @@ proc withWsTransport*(
    flags: set[ServerFlags] = {},
): SwitchBuilder =
  b.withTransport(
    proc(upgr: Upgrade, privateKey: PrivateKey, autotls: AutotlsService): Transport =
      WsTransport.new(upgr, tlsPrivateKey, tlsCertificate, tlsFlags, flags)
    proc(config: TransportConfig): Transport =
      WsTransport.new(
        config.upgr, tlsPrivateKey, tlsCertificate, config.autotls, tlsFlags, flags
      )
  )

when defined(libp2p_quic_support):
@@ -193,14 +219,14 @@ when defined(libp2p_quic_support):

  proc withQuicTransport*(b: SwitchBuilder): SwitchBuilder {.public.} =
    b.withTransport(
      proc(upgr: Upgrade, privateKey: PrivateKey, autotls: AutotlsService): Transport =
        QuicTransport.new(upgr, privateKey)
      proc(config: TransportConfig): Transport =
        QuicTransport.new(config.upgr, config.privateKey)
    )

proc withMemoryTransport*(b: SwitchBuilder): SwitchBuilder {.public.} =
  b.withTransport(
    proc(upgr: Upgrade, privateKey: PrivateKey, autotls: AutotlsService): Transport =
      MemoryTransport.new(upgr)
    proc(config: TransportConfig): Transport =
      MemoryTransport.new(config.upgr)
  )

proc withRng*(b: SwitchBuilder, rng: ref HmacDrbgContext): SwitchBuilder {.public.} =
@@ -257,12 +283,20 @@ proc withAutonat*(b: SwitchBuilder): SwitchBuilder =
  b.autonat = true
  b

proc withAutotls*(
    b: SwitchBuilder, config: AutotlsConfig = AutotlsConfig.new()
): SwitchBuilder {.public.} =
  b.autotls = AutotlsService.new(config = config)
proc withAutonatV2*(
    b: SwitchBuilder, config: AutonatV2Config = AutonatV2Config.new()
): SwitchBuilder =
  b.autonatV2 = true
  b.autonatV2Config = config
  b

when defined(libp2p_autotls_support):
  proc withAutotls*(
      b: SwitchBuilder, config: AutotlsConfig = AutotlsConfig.new()
  ): SwitchBuilder {.public.} =
    b.autotls = AutotlsService.new(config = config)
    b

proc withCircuitRelay*(b: SwitchBuilder, r: Relay = Relay.new()): SwitchBuilder =
  b.circuitRelay = r
  b
@@ -320,7 +354,11 @@ proc build*(b: SwitchBuilder): Switch {.raises: [LPError], public.} =
  let transports = block:
    var transports: seq[Transport]
    for tProvider in b.transports:
      transports.add(tProvider(muxedUpgrade, seckey, b.autotls))
      transports.add(
        tProvider(
          TransportConfig(upgr: muxedUpgrade, privateKey: seckey, autotls: b.autotls)
        )
      )
    transports

  if b.secureManagers.len == 0:
@@ -351,7 +389,10 @@ proc build*(b: SwitchBuilder): Switch {.raises: [LPError], public.} =

  switch.mount(identify)

  if b.autonat:
  if b.autonatV2:
    let autonatV2 = AutonatV2.new(switch, config = b.autonatV2Config)
    switch.mount(autonatV2)
  elif b.autonat:
    let autonat = Autonat.new(switch)
    switch.mount(autonat)

@@ -367,13 +408,78 @@ proc build*(b: SwitchBuilder): Switch {.raises: [LPError], public.} =

  return switch

proc newStandardSwitch*(
type TransportType* {.pure.} = enum
  QUIC
  TCP
  Memory

proc newStandardSwitchBuilder*(
    privKey = none(PrivateKey),
    addrs: MultiAddress | seq[MultiAddress] =
      MultiAddress.init("/ip4/127.0.0.1/tcp/0").expect("valid address"),
    secureManagers: openArray[SecureProtocol] = [SecureProtocol.Noise],
    addrs: MultiAddress | seq[MultiAddress] = newSeq[MultiAddress](),
    transport: TransportType = TransportType.TCP,
    transportFlags: set[ServerFlags] = {},
    rng = newRng(),
    secureManagers: openArray[SecureProtocol] = [SecureProtocol.Noise],
    inTimeout: Duration = 5.minutes,
    outTimeout: Duration = 5.minutes,
    maxConnections = MaxConnections,
    maxIn = -1,
    maxOut = -1,
    maxConnsPerPeer = MaxConnectionsPerPeer,
    nameResolver: NameResolver = nil,
    sendSignedPeerRecord = false,
    peerStoreCapacity = 1000,
): SwitchBuilder {.raises: [LPError], public.} =
  ## Helper for common switch configurations.
  var b = SwitchBuilder
    .new()
    .withRng(rng)
    .withSignedPeerRecord(sendSignedPeerRecord)
    .withMaxConnections(maxConnections)
    .withMaxIn(maxIn)
    .withMaxOut(maxOut)
    .withMaxConnsPerPeer(maxConnsPerPeer)
    .withPeerStore(capacity = peerStoreCapacity)
    .withNameResolver(nameResolver)
    .withNoise()

  var addrs =
    when addrs is MultiAddress:
      @[addrs]
    else:
      addrs

  case transport
  of TransportType.QUIC:
    when defined(libp2p_quic_support):
      if addrs.len == 0:
        addrs = @[MultiAddress.init("/ip4/0.0.0.0/udp/0/quic-v1").tryGet()]
      b = b.withQuicTransport().withAddresses(addrs)
    else:
      raiseAssert "QUIC not supported in this build"
  of TransportType.TCP:
    if addrs.len == 0:
      addrs = @[MultiAddress.init("/ip4/127.0.0.1/tcp/0").tryGet()]
    b = b.withTcpTransport(transportFlags).withAddresses(addrs).withMplex(
      inTimeout, outTimeout
    )
  of TransportType.Memory:
    if addrs.len == 0:
      addrs = @[MultiAddress.init(MemoryAutoAddress).tryGet()]
    b = b.withMemoryTransport().withAddresses(addrs).withMplex(inTimeout, outTimeout)

  privKey.withValue(pkey):
    b = b.withPrivateKey(pkey)

  b

proc newStandardSwitch*(
    privKey = none(PrivateKey),
    addrs: MultiAddress | seq[MultiAddress] = newSeq[MultiAddress](),
    transport: TransportType = TransportType.TCP,
    transportFlags: set[ServerFlags] = {},
    rng = newRng(),
    secureManagers: openArray[SecureProtocol] = [SecureProtocol.Noise],
    inTimeout: Duration = 5.minutes,
    outTimeout: Duration = 5.minutes,
    maxConnections = MaxConnections,
@@ -384,28 +490,21 @@ proc newStandardSwitch*(
    sendSignedPeerRecord = false,
    peerStoreCapacity = 1000,
): Switch {.raises: [LPError], public.} =
  ## Helper for common switch configurations.
  let addrs =
    when addrs is MultiAddress:
      @[addrs]
    else:
      addrs
  var b = SwitchBuilder
    .new()
    .withAddresses(addrs)
    .withRng(rng)
    .withSignedPeerRecord(sendSignedPeerRecord)
    .withMaxConnections(maxConnections)
    .withMaxIn(maxIn)
    .withMaxOut(maxOut)
    .withMaxConnsPerPeer(maxConnsPerPeer)
    .withPeerStore(capacity = peerStoreCapacity)
    .withMplex(inTimeout, outTimeout)
    .withTcpTransport(transportFlags)
    .withNameResolver(nameResolver)
    .withNoise()

  privKey.withValue(pkey):
    b = b.withPrivateKey(pkey)

  b.build()
  newStandardSwitchBuilder(
    privKey = privKey,
    addrs = addrs,
    transport = transport,
    transportFlags = transportFlags,
    rng = rng,
    secureManagers = secureManagers,
    inTimeout = inTimeout,
    outTimeout = outTimeout,
    maxConnections = maxConnections,
    maxIn = maxIn,
    maxOut = maxOut,
    maxConnsPerPeer = maxConnsPerPeer,
    nameResolver = nameResolver,
    sendSignedPeerRecord = sendSignedPeerRecord,
    peerStoreCapacity = peerStoreCapacity,
  )
  .build()
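A brief usage sketch of the reworked builder API in this hunk (newStandardSwitchBuilder plus the TransportBuilder overload of withTransport); the umbrella import is an assumption, since the diff only shows the builder module itself:

import libp2p # assumed umbrella module

# Build via the new helper and adjust the returned builder before .build()...
let switch1 = newStandardSwitchBuilder(transport = TransportType.TCP)
  .withMaxConnections(100)
  .build()

# ...or register a custom transport through the new TransportBuilder signature.
let switch2 = SwitchBuilder
  .new()
  .withRng(newRng())
  .withAddresses(@[MultiAddress.init("/ip4/127.0.0.1/tcp/0").tryGet()])
  .withTransport(
    proc(config: TransportConfig): Transport =
      TcpTransport.new({}, config.upgr)
  )
  .withMplex()
  .withNoise()
  .build()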
@@ -11,7 +11,7 @@

import chronos
import results
import peerid, stream/connection, transports/transport
import peerid, stream/connection, transports/transport, muxers/muxer

export results

@@ -65,6 +65,23 @@ method dial*(
method addTransport*(self: Dial, transport: Transport) {.base.} =
  doAssert(false, "[Dial.addTransport] abstract method not implemented!")

method dialAndUpgrade*(
    self: Dial, peerId: Opt[PeerId], address: MultiAddress, dir = Direction.Out
): Future[Muxer] {.base, async: (raises: [CancelledError]).} =
  doAssert(false, "[Dial.dialAndUpgrade] abstract method not implemented!")

method dialAndUpgrade*(
    self: Dial, peerId: Opt[PeerId], addrs: seq[MultiAddress], dir = Direction.Out
): Future[Muxer] {.
    base, async: (raises: [CancelledError, MaError, TransportAddressError, LPError])
.} =
  doAssert(false, "[Dial.dialAndUpgrade] abstract method not implemented!")

method negotiateStream*(
    self: Dial, conn: Connection, protos: seq[string]
): Future[Connection] {.base, async: (raises: [CatchableError]).} =
  doAssert(false, "[Dial.negotiateStream] abstract method not implemented!")

method tryDial*(
    self: Dial, peerId: PeerId, addrs: seq[MultiAddress]
): Future[Opt[MultiAddress]] {.
@@ -43,7 +43,7 @@ type Dialer* = ref object of Dial
  peerStore: PeerStore
  nameResolver: NameResolver

proc dialAndUpgrade(
proc dialAndUpgrade*(
    self: Dialer,
    peerId: Opt[PeerId],
    hostname: string,
@@ -104,12 +104,13 @@ proc expandDnsAddr(
): Future[seq[(MultiAddress, Opt[PeerId])]] {.
    async: (raises: [CancelledError, MaError, TransportAddressError, LPError])
.} =
  if not DNSADDR.matchPartial(address):
  if not DNS.matchPartial(address):
    return @[(address, peerId)]
  if isNil(self.nameResolver):
    info "Can't resolve DNSADDR without NameResolver", ma = address
    return @[]

  trace "Start trying to resolve addresses"
  let
    toResolve =
      if peerId.isSome:
@@ -121,6 +122,9 @@ proc expandDnsAddr(
        address
    resolved = await self.nameResolver.resolveDnsAddr(toResolve)

  debug "resolved addresses",
    originalAddresses = toResolve, resolvedAddresses = resolved

  for resolvedAddress in resolved:
    let lastPart = resolvedAddress[^1].tryGet()
    if lastPart.protoCode == Result[MultiCodec, string].ok(multiCodec("p2p")):
@@ -135,7 +139,7 @@ proc expandDnsAddr(
  else:
    result.add((resolvedAddress, peerId))

proc dialAndUpgrade(
proc dialAndUpgrade*(
    self: Dialer, peerId: Opt[PeerId], addrs: seq[MultiAddress], dir = Direction.Out
): Future[Muxer] {.
    async: (raises: [CancelledError, MaError, TransportAddressError, LPError])
@@ -145,7 +149,6 @@ proc dialAndUpgrade(
  for rawAddress in addrs:
    # resolve potential dnsaddr
    let addresses = await self.expandDnsAddr(peerId, rawAddress)

    for (expandedAddress, addrPeerId) in addresses:
      # DNS resolution
      let
@@ -156,6 +159,11 @@ proc dialAndUpgrade(
        else:
          await self.nameResolver.resolveMAddress(expandedAddress)

      debug "Expanded address and hostname",
        expandedAddress = expandedAddress,
        hostname = hostname,
        resolvedAddresses = resolvedAddresses

      for resolvedAddress in resolvedAddresses:
        result = await self.dialAndUpgrade(addrPeerId, hostname, resolvedAddress, dir)
        if not isNil(result):
@@ -276,7 +284,7 @@ method connect*(
  return
    (await self.internalConnect(Opt.none(PeerId), @[address], false)).connection.peerId

proc negotiateStream(
proc negotiateStream*(
    self: Dialer, conn: Connection, protos: seq[string]
): Future[Connection] {.async: (raises: [CatchableError]).} =
  trace "Negotiating stream", conn, protos
@@ -159,7 +159,7 @@ proc stop*(query: DiscoveryQuery) =
  query.finished = true
  for r in query.futs:
    if not r.finished():
      r.cancel()
      r.cancelSoon()

proc stop*(dm: DiscoveryManager) =
  for q in dm.queries:
@@ -167,7 +167,7 @@ proc stop*(dm: DiscoveryManager) =
  for i in dm.interfaces:
    if isNil(i.advertiseLoop):
      continue
    i.advertiseLoop.cancel()
    i.advertiseLoop.cancelSoon()

proc getPeer*(
    query: DiscoveryQuery
@@ -179,7 +179,7 @@ proc getPeer*(
  try:
    await getter or allFinished(query.futs)
  except CancelledError as exc:
    getter.cancel()
    getter.cancelSoon()
    raise exc

  if not finished(getter):
@@ -27,7 +27,7 @@ macro checkFutures*[F](futs: seq[F], exclude: untyped = []): untyped =
    quote:
      for res in `futs`:
        if res.failed:
          let exc = res.readError()
          let exc = res.error
          # We still don't abort but warn
          debug "A future has failed, enable trace logging for details",
            error = exc.name
@@ -37,7 +37,7 @@ macro checkFutures*[F](futs: seq[F], exclude: untyped = []): untyped =
      for res in `futs`:
        block check:
          if res.failed:
            let exc = res.readError()
            let exc = res.error
            for i in 0 ..< `nexclude`:
              if exc of `exclude`[i]:
                trace "A future has failed", error = exc.name, description = exc.msg
@@ -843,6 +843,14 @@ proc init*(
|
||||
res.data.finish()
|
||||
ok(res)
|
||||
|
||||
proc getPart*(ma: MultiAddress, codec: MultiCodec): MaResult[MultiAddress] =
|
||||
## Returns the first part of ``ma`` with codec ``codec``
|
||||
for part in ma:
|
||||
let part = ?part
|
||||
if codec == ?part.protoCode:
|
||||
return ok(part)
|
||||
err("no such codec in multiaddress")
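A minimal usage sketch for getPart; the address literal and the tryGet-based error handling are illustrative assumptions, not part of the change above:

# Hypothetical: pull the tcp component out of a composite address.
let ma = MultiAddress.init("/ip4/127.0.0.1/tcp/4001").tryGet()
let tcpPart = ma.getPart(multiCodec("tcp")).tryGet()
assert $tcpPart == "/tcp/4001"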
|
||||
|
||||
proc getProtocol(name: string): MAProtocol {.inline.} =
|
||||
let mc = MultiCodec.codec(name)
|
||||
if mc != InvalidMultiCodec:
|
||||
@@ -1119,3 +1127,32 @@ proc getRepeatedField*(
|
||||
err(ProtoError.IncorrectBlob)
|
||||
else:
|
||||
ok(true)
|
||||
|
||||
proc areAddrsConsistent*(a, b: MultiAddress): bool =
|
||||
## Checks if two multiaddresses have the same protocol stack.
|
||||
let protosA = a.protocols().get()
|
||||
let protosB = b.protocols().get()
|
||||
if protosA.len != protosB.len:
|
||||
return false
|
||||
|
||||
for idx in 0 ..< protosA.len:
|
||||
let protoA = protosA[idx]
|
||||
let protoB = protosB[idx]
|
||||
|
||||
if protoA != protoB:
|
||||
if idx == 0:
|
||||
# allow DNS ↔ IP at the first component
|
||||
if protoB == multiCodec("dns") or protoB == multiCodec("dnsaddr"):
|
||||
if not (protoA == multiCodec("ip4") or protoA == multiCodec("ip6")):
|
||||
return false
|
||||
elif protoB == multiCodec("dns4"):
|
||||
if protoA != multiCodec("ip4"):
|
||||
return false
|
||||
elif protoB == multiCodec("dns6"):
|
||||
if protoA != multiCodec("ip6"):
|
||||
return false
|
||||
else:
|
||||
return false
|
||||
else:
|
||||
return false
|
||||
true
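A rough sketch of the intended behaviour; the addresses are made up and only illustrate the DNS-to-IP allowance at the first component:

# dns4 vs ip4 with the same tail is consistent, swapping the address family is not.
let ip4Addr = MultiAddress.init("/ip4/192.0.2.1/tcp/4001").tryGet()
let dns4Addr = MultiAddress.init("/dns4/example.org/tcp/4001").tryGet()
let ip6Addr = MultiAddress.init("/ip6/::1/tcp/4001").tryGet()
assert areAddrsConsistent(ip4Addr, dns4Addr)
assert not areAddrsConsistent(ip6Addr, dns4Addr)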
|
||||
|
||||
@@ -249,11 +249,7 @@ proc addHandler*[E](
|
||||
m.handlers.add(HandlerHolder(protos: @[codec], protocol: protocol, match: matcher))
|
||||
|
||||
proc start*(m: MultistreamSelect) {.async: (raises: [CancelledError]).} =
|
||||
# Nim 1.6.18: Using `mapIt` results in a seq of `.Raising([])`
|
||||
# TODO https://github.com/nim-lang/Nim/issues/23445
|
||||
var futs = newSeqOfCap[Future[void].Raising([CancelledError])](m.handlers.len)
|
||||
for it in m.handlers:
|
||||
futs.add it.protocol.start()
|
||||
let futs = m.handlers.mapIt(it.protocol.start())
|
||||
try:
|
||||
await allFutures(futs)
|
||||
for fut in futs:
|
||||
@@ -273,10 +269,7 @@ proc start*(m: MultistreamSelect) {.async: (raises: [CancelledError]).} =
|
||||
raise exc
|
||||
|
||||
proc stop*(m: MultistreamSelect) {.async: (raises: []).} =
|
||||
# Nim 1.6.18: Using `mapIt` results in a seq of `.Raising([CancelledError])`
|
||||
var futs = newSeqOfCap[Future[void].Raising([])](m.handlers.len)
|
||||
for it in m.handlers:
|
||||
futs.add it.protocol.stop()
|
||||
let futs = m.handlers.mapIt(it.protocol.stop())
|
||||
await noCancel allFutures(futs)
|
||||
for fut in futs:
|
||||
await fut
|
||||
|
||||
@@ -150,6 +150,10 @@ method close*(s: LPChannel) {.async: (raises: []).} =
|
||||
|
||||
trace "Closed channel", s, len = s.len
|
||||
|
||||
method closeWrite*(s: LPChannel) {.async: (raises: []).} =
|
||||
## For mplex, closeWrite is the same as close - it implements half-close
|
||||
await s.close()
|
||||
|
||||
method initStream*(s: LPChannel) =
|
||||
if s.objName.len == 0:
|
||||
s.objName = LPChannelTrackerName
|
||||
|
||||
@@ -95,6 +95,7 @@ proc newStreamInternal*(
|
||||
|
||||
result.peerId = m.connection.peerId
|
||||
result.observedAddr = m.connection.observedAddr
|
||||
result.localAddr = m.connection.localAddr
|
||||
result.transportDir = m.connection.transportDir
|
||||
when defined(libp2p_agents_metrics):
|
||||
result.shortAgent = m.connection.shortAgent
|
||||
|
||||
@@ -54,6 +54,10 @@ method newStream*(
|
||||
.} =
|
||||
raiseAssert("[Muxer.newStream] abstract method not implemented!")
|
||||
|
||||
when defined(libp2p_agents_metrics):
|
||||
method setShortAgent*(m: Muxer, shortAgent: string) {.base, gcsafe.} =
|
||||
m.connection.shortAgent = shortAgent
|
||||
|
||||
method close*(m: Muxer) {.base, async: (raises: []).} =
|
||||
if m.connection != nil:
|
||||
await m.connection.close()
|
||||
|
||||
@@ -135,12 +135,11 @@ proc windowUpdate(
|
||||
)
|
||||
|
||||
type
|
||||
ToSend =
|
||||
tuple[
|
||||
data: seq[byte],
|
||||
sent: int,
|
||||
fut: Future[void].Raising([CancelledError, LPStreamError]),
|
||||
]
|
||||
ToSend = ref object
|
||||
data: seq[byte]
|
||||
sent: int
|
||||
fut: Future[void].Raising([CancelledError, LPStreamError])
|
||||
|
||||
YamuxChannel* = ref object of Connection
|
||||
id: uint32
|
||||
recvWindow: int
|
||||
@@ -218,6 +217,19 @@ method closeImpl*(channel: YamuxChannel) {.async: (raises: []).} =
|
||||
discard
|
||||
await channel.actuallyClose()
|
||||
|
||||
method closeWrite*(channel: YamuxChannel) {.async: (raises: []).} =
|
||||
## For yamux, closeWrite is the same as close - it implements half-close
|
||||
await channel.close()
|
||||
|
||||
proc clearQueues(channel: YamuxChannel, error: ref LPStreamEOFError = nil) =
|
||||
for toSend in channel.sendQueue:
|
||||
if error.isNil():
|
||||
toSend.fut.complete()
|
||||
else:
|
||||
toSend.fut.fail(error)
|
||||
channel.sendQueue = @[]
|
||||
channel.recvQueue.clear()
|
||||
|
||||
proc reset(channel: YamuxChannel, isLocal: bool = false) {.async: (raises: []).} =
|
||||
# If we reset locally, we want to flush up to a maximum of recvWindow
|
||||
# bytes. It's because the peer we're connected to can send us data before
|
||||
@@ -227,9 +239,8 @@ proc reset(channel: YamuxChannel, isLocal: bool = false) {.async: (raises: []).}
|
||||
trace "Reset channel"
|
||||
channel.isReset = true
|
||||
channel.remoteReset = not isLocal
|
||||
for (d, s, fut) in channel.sendQueue:
|
||||
fut.fail(newLPStreamEOFError())
|
||||
channel.sendQueue = @[]
|
||||
channel.clearQueues(newLPStreamEOFError())
|
||||
|
||||
channel.sendWindow = 0
|
||||
if not channel.closedLocally:
|
||||
if isLocal and not channel.isSending:
|
||||
@@ -278,6 +289,7 @@ method readOnce*(
|
||||
trace "stream is down when readOnce", channel = $channel
|
||||
newLPStreamConnDownError()
|
||||
if channel.isEof:
|
||||
channel.clearQueues()
|
||||
raise newLPStreamRemoteClosedError()
|
||||
if channel.recvQueue.isEmpty():
|
||||
channel.receivedData.clear()
|
||||
@@ -292,6 +304,7 @@ method readOnce*(
|
||||
await closedRemotelyFut or receivedDataFut
|
||||
if channel.closedRemotely.isSet() and channel.recvQueue.isEmpty():
|
||||
channel.isEof = true
|
||||
channel.clearQueues()
|
||||
return
|
||||
0 # we return 0 to indicate that the channel is closed for reading from now on
|
||||
|
||||
@@ -315,17 +328,18 @@ proc gotDataFromRemote(
|
||||
proc setMaxRecvWindow*(channel: YamuxChannel, maxRecvWindow: int) =
|
||||
channel.maxRecvWindow = maxRecvWindow
|
||||
|
||||
proc trySend(
|
||||
channel: YamuxChannel
|
||||
) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
proc sendLoop(channel: YamuxChannel) {.async: (raises: []).} =
|
||||
if channel.isSending:
|
||||
return
|
||||
channel.isSending = true
|
||||
defer:
|
||||
channel.isSending = false
|
||||
|
||||
while channel.sendQueue.len != 0:
|
||||
channel.sendQueue.keepItIf(not (it.fut.cancelled() and it.sent == 0))
|
||||
const NumBytesHeader = 12
|
||||
|
||||
while channel.sendQueue.len > 0:
|
||||
channel.sendQueue.keepItIf(not it.fut.finished())
|
||||
|
||||
if channel.sendWindow == 0:
|
||||
trace "trying to send while the sendWindow is empty"
|
||||
if channel.lengthSendQueueWithLimit() > channel.maxSendQueueSize:
|
||||
@@ -337,54 +351,57 @@ proc trySend(
|
||||
|
||||
let
|
||||
bytesAvailable = channel.lengthSendQueue()
|
||||
toSend = min(channel.sendWindow, bytesAvailable)
|
||||
numBytesToSend = min(channel.sendWindow, bytesAvailable)
|
||||
var
|
||||
sendBuffer = newSeqUninit[byte](toSend + 12)
|
||||
header = YamuxHeader.data(channel.id, toSend.uint32)
|
||||
sendBuffer = newSeqUninit[byte](NumBytesHeader + numBytesToSend)
|
||||
header = YamuxHeader.data(channel.id, numBytesToSend.uint32)
|
||||
inBuffer = 0
|
||||
|
||||
if toSend >= bytesAvailable and channel.closedLocally:
|
||||
trace "last buffer we'll sent on this channel", toSend, bytesAvailable
|
||||
if numBytesToSend >= bytesAvailable and channel.closedLocally:
|
||||
trace "last buffer we will send on this channel", numBytesToSend, bytesAvailable
|
||||
header.flags.incl({Fin})
|
||||
|
||||
sendBuffer[0 ..< 12] = header.encode()
|
||||
sendBuffer[0 ..< NumBytesHeader] = header.encode()
|
||||
|
||||
var futures: seq[Future[void].Raising([CancelledError, LPStreamError])]
|
||||
while inBuffer < toSend:
|
||||
while inBuffer < numBytesToSend:
|
||||
var toSend = channel.sendQueue[0]
|
||||
# concatenate the different messages we try to send into one buffer
|
||||
let (data, sent, fut) = channel.sendQueue[0]
|
||||
let bufferToSend = min(data.len - sent, toSend - inBuffer)
|
||||
let bufferToSend = min(toSend.data.len - toSend.sent, numBytesToSend - inBuffer)
|
||||
|
||||
sendBuffer.toOpenArray(12, 12 + toSend - 1)[
|
||||
sendBuffer.toOpenArray(NumBytesHeader, NumBytesHeader + numBytesToSend - 1)[
|
||||
inBuffer ..< (inBuffer + bufferToSend)
|
||||
] = channel.sendQueue[0].data.toOpenArray(sent, sent + bufferToSend - 1)
|
||||
] = toSend.data.toOpenArray(toSend.sent, toSend.sent + bufferToSend - 1)
|
||||
|
||||
channel.sendQueue[0].sent.inc(bufferToSend)
|
||||
if channel.sendQueue[0].sent >= data.len:
|
||||
|
||||
if toSend.sent >= toSend.data.len:
|
||||
# if every byte of the message is in the buffer, add the write future to the
|
||||
# sequence of futures to be completed (or failed) when the buffer is sent
|
||||
futures.add(fut)
|
||||
futures.add(toSend.fut)
|
||||
channel.sendQueue.delete(0)
|
||||
|
||||
inBuffer.inc(bufferToSend)
|
||||
|
||||
trace "try to send the buffer", h = $header
|
||||
channel.sendWindow.dec(toSend)
|
||||
try:
|
||||
await channel.conn.write(sendBuffer)
|
||||
channel.sendWindow.dec(inBuffer)
|
||||
except CancelledError:
|
||||
trace "cancelled sending the buffer"
|
||||
for fut in futures.items():
|
||||
fut.cancelSoon()
|
||||
await channel.reset()
|
||||
break
|
||||
## Just for the compiler. This should never happen, as sendLoop is started by asyncSpawn.
|
||||
## Therefore, no one owns that sendLoop's future and no one can cancel it.
|
||||
discard
|
||||
except LPStreamError as exc:
|
||||
trace "failed to send the buffer"
|
||||
error "failed to send the buffer", description = exc.msg
|
||||
let connDown = newLPStreamConnDownError(exc)
|
||||
for fut in futures.items():
|
||||
for fut in futures:
|
||||
fut.fail(connDown)
|
||||
await channel.reset()
|
||||
break
|
||||
for fut in futures.items():
|
||||
|
||||
for fut in futures:
|
||||
fut.complete()
|
||||
|
||||
channel.activity = true
|
||||
|
||||
method write*(
|
||||
@@ -392,21 +409,29 @@ method write*(
|
||||
): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true).} =
|
||||
## Write to yamux channel
|
||||
##
|
||||
result = newFuture[void]("Yamux Send")
|
||||
var resFut = newFuture[void]("Yamux Send")
|
||||
|
||||
if channel.remoteReset:
|
||||
trace "stream is reset when write", channel = $channel
|
||||
result.fail(newLPStreamResetError())
|
||||
return result
|
||||
resFut.fail(newLPStreamResetError())
|
||||
return resFut
|
||||
|
||||
if channel.closedLocally or channel.isReset:
|
||||
result.fail(newLPStreamClosedError())
|
||||
return result
|
||||
resFut.fail(newLPStreamClosedError())
|
||||
return resFut
|
||||
|
||||
if msg.len == 0:
|
||||
result.complete()
|
||||
return result
|
||||
channel.sendQueue.add((msg, 0, result))
|
||||
resFut.complete()
|
||||
return resFut
|
||||
|
||||
channel.sendQueue.add(ToSend(data: msg, sent: 0, fut: resFut))
|
||||
|
||||
when defined(libp2p_yamux_metrics):
|
||||
libp2p_yamux_send_queue.observe(channel.lengthSendQueue().int64)
|
||||
asyncSpawn channel.trySend()
|
||||
|
||||
asyncSpawn channel.sendLoop()
|
||||
|
||||
return resFut
|
||||
|
||||
proc open(channel: YamuxChannel) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
## Open a yamux channel by sending a window update with Syn or Ack flag
|
||||
@@ -415,6 +440,8 @@ proc open(channel: YamuxChannel) {.async: (raises: [CancelledError, LPStreamErro
|
||||
trace "Try to open channel twice"
|
||||
return
|
||||
channel.opened = true
|
||||
channel.isReset = false
|
||||
|
||||
await channel.conn.write(
|
||||
YamuxHeader.windowUpdate(
|
||||
channel.id,
|
||||
@@ -488,6 +515,7 @@ proc createStream(
|
||||
stream.initStream()
|
||||
stream.peerId = m.connection.peerId
|
||||
stream.observedAddr = m.connection.observedAddr
|
||||
stream.localAddr = m.connection.localAddr
|
||||
stream.transportDir = m.connection.transportDir
|
||||
when defined(libp2p_agents_metrics):
|
||||
stream.shortAgent = m.connection.shortAgent
|
||||
@@ -502,18 +530,17 @@ method close*(m: Yamux) {.async: (raises: []).} =
|
||||
if m.isClosed == true:
|
||||
trace "Already closed"
|
||||
return
|
||||
m.isClosed = true
|
||||
|
||||
trace "Closing yamux"
|
||||
let channels = toSeq(m.channels.values())
|
||||
for channel in channels:
|
||||
for (d, s, fut) in channel.sendQueue:
|
||||
fut.fail(newLPStreamEOFError())
|
||||
channel.sendQueue = @[]
|
||||
channel.clearQueues(newLPStreamEOFError())
|
||||
channel.recvWindow = 0
|
||||
channel.sendWindow = 0
|
||||
channel.closedLocally = true
|
||||
channel.isReset = true
|
||||
channel.opened = false
|
||||
channel.isClosed = true
|
||||
await channel.remoteClosed()
|
||||
channel.receivedData.fire()
|
||||
try:
|
||||
@@ -523,6 +550,8 @@ method close*(m: Yamux) {.async: (raises: []).} =
|
||||
except LPStreamError as exc:
|
||||
trace "failed to send goAway", description = exc.msg
|
||||
await m.connection.close()
|
||||
|
||||
m.isClosed = true
|
||||
trace "Closed yamux"
|
||||
|
||||
proc handleStream(m: Yamux, channel: YamuxChannel) {.async: (raises: []).} =
|
||||
@@ -583,8 +612,10 @@ method handle*(m: Yamux) {.async: (raises: []).} =
|
||||
if header.length > 0:
|
||||
var buffer = newSeqUninit[byte](header.length)
|
||||
await m.connection.readExactly(addr buffer[0], int(header.length))
|
||||
do:
|
||||
raise newException(YamuxError, "Unknown stream ID: " & $header.streamId)
|
||||
|
||||
# If we do not have a stream, likely we sent a RST and/or closed the stream
|
||||
trace "unknown stream id", id = header.streamId
|
||||
|
||||
continue
|
||||
|
||||
let channel =
|
||||
@@ -600,7 +631,7 @@ method handle*(m: Yamux) {.async: (raises: []).} =
|
||||
|
||||
if header.msgType == WindowUpdate:
|
||||
channel.sendWindow += int(header.length)
|
||||
await channel.trySend()
|
||||
asyncSpawn channel.sendLoop()
|
||||
else:
|
||||
if header.length.int > channel.recvWindow.int:
|
||||
# check before allocating the buffer
|
||||
|
||||
@@ -52,6 +52,14 @@ func shortLog*(p: PeerInfo): auto =
|
||||
chronicles.formatIt(PeerInfo):
|
||||
shortLog(it)
|
||||
|
||||
proc expandAddrs*(
|
||||
p: PeerInfo
|
||||
): Future[seq[MultiAddress]] {.async: (raises: [CancelledError]).} =
|
||||
var addrs = p.listenAddrs
|
||||
for mapper in p.addressMappers:
|
||||
addrs = await mapper(addrs)
|
||||
addrs
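A hedged call-site sketch: unlike update, expandAddrs only returns the mapped addresses and leaves the PeerInfo untouched (the switch variable is assumed to exist):

# Hypothetical: log the externally visible addresses without mutating peerInfo.
let mapped = await switch.peerInfo.expandAddrs()
info "external addresses", addrs = mapped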
|
||||
|
||||
proc update*(p: PeerInfo) {.async: (raises: [CancelledError]).} =
|
||||
p.addrs = p.listenAddrs
|
||||
for mapper in p.addressMappers:
|
||||
|
||||
@@ -214,7 +214,7 @@ proc identify*(
|
||||
info.agentVersion.get("").split("/")[0].safeToLowerAscii().get("")
|
||||
if KnownLibP2PAgentsSeq.contains(shortAgent):
|
||||
knownAgent = shortAgent
|
||||
muxer.connection.setShortAgent(knownAgent)
|
||||
muxer.setShortAgent(knownAgent)
|
||||
|
||||
peerStore.updatePeerInfo(info, stream.observedAddr)
|
||||
finally:
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
import results
|
||||
import chronos, chronicles
|
||||
import ../../../switch, ../../../multiaddress, ../../../peerid
|
||||
import core
|
||||
import types
|
||||
|
||||
logScope:
|
||||
topics = "libp2p autonat"
|
||||
|
||||
@@ -20,9 +20,9 @@ import
|
||||
../../../peerid,
|
||||
../../../utils/[semaphore, future],
|
||||
../../../errors
|
||||
import core
|
||||
import types
|
||||
|
||||
export core
|
||||
export types
|
||||
|
||||
logScope:
|
||||
topics = "libp2p autonat"
|
||||
@@ -105,7 +105,7 @@ proc tryDial(
|
||||
autonat.sem.release()
|
||||
for f in futs:
|
||||
if not f.finished():
|
||||
f.cancel()
|
||||
f.cancelSoon()
|
||||
|
||||
proc handleDial(autonat: Autonat, conn: Connection, msg: AutonatMsg): Future[void] =
|
||||
let dial = msg.dial.valueOr:
|
||||
|
||||
@@ -14,11 +14,11 @@ import chronos, metrics
|
||||
import ../../../switch
|
||||
import ../../../wire
|
||||
import client
|
||||
from core import NetworkReachability, AutonatUnreachableError
|
||||
from types import NetworkReachability, AutonatUnreachableError
|
||||
import ../../../utils/heartbeat
|
||||
import ../../../crypto/crypto
|
||||
|
||||
export core.NetworkReachability
|
||||
export NetworkReachability
|
||||
|
||||
logScope:
|
||||
topics = "libp2p autonatservice"
|
||||
|
||||
279
libp2p/protocols/connectivity/autonatv2/server.nim
Normal file
@@ -0,0 +1,279 @@
|
||||
# Nim-LibP2P
|
||||
# Copyright (c) 2025 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
{.push raises: [].}
|
||||
|
||||
import results
|
||||
import chronos, chronicles
|
||||
import
|
||||
../../../../libp2p/[
|
||||
switch,
|
||||
muxers/muxer,
|
||||
dialer,
|
||||
multiaddress,
|
||||
transports/transport,
|
||||
multicodec,
|
||||
peerid,
|
||||
protobuf/minprotobuf,
|
||||
utils/ipaddr,
|
||||
],
|
||||
../../protocol,
|
||||
./types
|
||||
|
||||
logScope:
|
||||
topics = "libp2p autonat v2 server"
|
||||
|
||||
type AutonatV2Config* = object
|
||||
dialTimeout: Duration
|
||||
dialDataSize: uint64
|
||||
amplificationAttackTimeout: Duration
|
||||
allowPrivateAddresses: bool
|
||||
|
||||
type AutonatV2* = ref object of LPProtocol
|
||||
switch*: Switch
|
||||
config: AutonatV2Config
|
||||
|
||||
proc new*(
|
||||
T: typedesc[AutonatV2Config],
|
||||
dialTimeout: Duration = DefaultDialTimeout,
|
||||
dialDataSize: uint64 = DefaultDialDataSize,
|
||||
amplificationAttackTimeout: Duration = DefaultAmplificationAttackDialTimeout,
|
||||
allowPrivateAddresses: bool = false,
|
||||
): T =
|
||||
T(
|
||||
dialTimeout: dialTimeout,
|
||||
dialDataSize: dialDataSize,
|
||||
amplificationAttackTimeout: amplificationAttackTimeout,
|
||||
allowPrivateAddresses: allowPrivateAddresses,
|
||||
)
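A construction sketch with arbitrary values chosen purely for illustration; switch is an assumed Switch instance:

# Hypothetical: a more permissive server config for local testing.
let cfg = AutonatV2Config.new(dialTimeout = 10.seconds, allowPrivateAddresses = true)
let autonatServer = AutonatV2.new(switch, config = cfg)
switch.mount(autonatServer)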
|
||||
|
||||
proc sendDialResponse(
|
||||
conn: Connection,
|
||||
status: ResponseStatus,
|
||||
addrIdx: Opt[AddrIdx] = Opt.none(AddrIdx),
|
||||
dialStatus: Opt[DialStatus] = Opt.none(DialStatus),
|
||||
) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
await conn.writeLp(
|
||||
AutonatV2Msg(
|
||||
msgType: MsgType.DialResponse,
|
||||
dialResp: DialResponse(status: status, addrIdx: addrIdx, dialStatus: dialStatus),
|
||||
).encode().buffer
|
||||
)
|
||||
|
||||
proc findObservedIPAddr*(
|
||||
conn: Connection, req: DialRequest
|
||||
): Future[Opt[MultiAddress]] {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
let observedAddr = conn.observedAddr.valueOr:
|
||||
await conn.sendDialResponse(ResponseStatus.EInternalError)
|
||||
return Opt.none(MultiAddress)
|
||||
|
||||
let isRelayed = observedAddr.contains(multiCodec("p2p-circuit")).valueOr:
|
||||
error "Invalid observed address"
|
||||
await conn.sendDialResponse(ResponseStatus.EDialRefused)
|
||||
return Opt.none(MultiAddress)
|
||||
|
||||
if isRelayed:
|
||||
error "Invalid observed address: relayed address"
|
||||
await conn.sendDialResponse(ResponseStatus.EDialRefused)
|
||||
return Opt.none(MultiAddress)
|
||||
|
||||
let hostIp = observedAddr[0].valueOr:
|
||||
error "Invalid observed address"
|
||||
await conn.sendDialResponse(ResponseStatus.EInternalError)
|
||||
return Opt.none(MultiAddress)
|
||||
|
||||
return Opt.some(hostIp)
|
||||
|
||||
proc dialBack(
|
||||
conn: Connection, nonce: Nonce
|
||||
): Future[DialStatus] {.
|
||||
async: (raises: [CancelledError, DialFailedError, LPStreamError])
|
||||
.} =
|
||||
try:
|
||||
# send dial back
|
||||
await conn.writeLp(DialBack(nonce: nonce).encode().buffer)
|
||||
|
||||
# receive DialBackResponse
|
||||
let dialBackResp = DialBackResponse.decode(
|
||||
initProtoBuffer(await conn.readLp(AutonatV2MsgLpSize))
|
||||
).valueOr:
|
||||
error "DialBack failed, could not decode DialBackResponse"
|
||||
return DialStatus.EDialBackError
|
||||
except LPStreamRemoteClosedError as exc:
|
||||
# failed because of nonce error (remote reset the stream): EDialBackError
|
||||
error "DialBack failed, remote closed the connection", description = exc.msg
|
||||
return DialStatus.EDialBackError
|
||||
|
||||
# TODO: failed because of client or server resources: EDialError
|
||||
|
||||
trace "DialBack successful"
|
||||
return DialStatus.Ok
|
||||
|
||||
proc handleDialDataResponses(
|
||||
self: AutonatV2, conn: Connection
|
||||
) {.async: (raises: [CancelledError, AutonatV2Error, LPStreamError]).} =
|
||||
var dataReceived: uint64 = 0
|
||||
|
||||
while dataReceived < self.config.dialDataSize:
|
||||
let msg = AutonatV2Msg.decode(
|
||||
initProtoBuffer(await conn.readLp(AutonatV2DialDataResponseLpSize))
|
||||
).valueOr:
|
||||
raise newException(AutonatV2Error, "Received malformed message")
|
||||
debug "Received message", msgType = $msg.msgType
|
||||
if msg.msgType != MsgType.DialDataResponse:
|
||||
raise
|
||||
newException(AutonatV2Error, "Expecting DialDataResponse, got " & $msg.msgType)
|
||||
let resp = msg.dialDataResp
|
||||
dataReceived += resp.data.len.uint64
|
||||
debug "received data",
|
||||
dataReceived = resp.data.len.uint64, totalDataReceived = dataReceived
|
||||
|
||||
proc amplificationAttackPrevention(
|
||||
self: AutonatV2, conn: Connection, addrIdx: AddrIdx
|
||||
): Future[bool] {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
# send DialDataRequest
|
||||
await conn.writeLp(
|
||||
AutonatV2Msg(
|
||||
msgType: MsgType.DialDataRequest,
|
||||
dialDataReq: DialDataRequest(addrIdx: addrIdx, numBytes: self.config.dialDataSize),
|
||||
).encode().buffer
|
||||
)
|
||||
|
||||
# receive DialDataResponses until we're satisfied
|
||||
try:
|
||||
if not await self.handleDialDataResponses(conn).withTimeout(self.config.dialTimeout):
|
||||
error "Amplification attack prevention timeout",
|
||||
timeout = self.config.amplificationAttackTimeout, peer = conn.peerId
|
||||
return false
|
||||
except AutonatV2Error as exc:
|
||||
error "Amplification attack prevention failed", description = exc.msg
|
||||
return false
|
||||
|
||||
return true
|
||||
|
||||
proc canDial(self: AutonatV2, addrs: MultiAddress): bool =
|
||||
let (ipv4Support, ipv6Support) = self.switch.peerInfo.listenAddrs.ipSupport()
|
||||
addrs[0].withValue(addrIp):
|
||||
if IP4.match(addrIp) and not ipv4Support:
|
||||
return false
|
||||
if IP6.match(addrIp) and not ipv6Support:
|
||||
return false
|
||||
try:
|
||||
if not self.config.allowPrivateAddresses and isPrivate($addrIp):
|
||||
return false
|
||||
except ValueError:
|
||||
warn "Unable to parse IP address, skipping", addrs = $addrIp
|
||||
return false
|
||||
for t in self.switch.transports:
|
||||
if t.handles(addrs):
|
||||
return true
|
||||
return false
|
||||
|
||||
proc forceNewConnection(
|
||||
self: AutonatV2, pid: PeerId, addrs: seq[MultiAddress]
|
||||
): Future[Opt[Connection]] {.async: (raises: [CancelledError]).} =
|
||||
## Bypasses connManager to force a new connection to ``pid``
|
||||
## instead of reusing an existing one
|
||||
try:
|
||||
let mux = await self.switch.dialer.dialAndUpgrade(Opt.some(pid), addrs)
|
||||
if mux.isNil():
|
||||
return Opt.none(Connection)
|
||||
return Opt.some(
|
||||
await self.switch.negotiateStream(
|
||||
await mux.newStream(), @[$AutonatV2Codec.DialBack]
|
||||
)
|
||||
)
|
||||
except CancelledError as exc:
|
||||
raise exc
|
||||
except CatchableError:
|
||||
return Opt.none(Connection)
|
||||
|
||||
proc chooseDialAddr(
|
||||
self: AutonatV2, pid: PeerId, addrs: seq[MultiAddress]
|
||||
): Future[Opt[(Connection, AddrIdx)]] {.async: (raises: [CancelledError]).} =
|
||||
for i, ma in addrs:
|
||||
if self.canDial(ma):
|
||||
debug "Trying to dial", chosenAddrs = ma, addrIdx = i
|
||||
let conn = (await self.forceNewConnection(pid, @[ma])).valueOr:
|
||||
return Opt.none((Connection, AddrIdx))
|
||||
return Opt.some((conn, i.AddrIdx))
|
||||
return Opt.none((Connection, AddrIdx))
|
||||
|
||||
proc handleDialRequest(
|
||||
self: AutonatV2, conn: Connection, req: DialRequest
|
||||
) {.async: (raises: [CancelledError, DialFailedError, LPStreamError]).} =
|
||||
let observedIPAddr = (await conn.findObservedIPAddr(req)).valueOr:
|
||||
error "Could not find observed IP address"
|
||||
return
|
||||
|
||||
let (dialBackConn, addrIdx) = (await self.chooseDialAddr(conn.peerId, req.addrs)).valueOr:
|
||||
error "No dialable addresses found"
|
||||
await conn.sendDialResponse(ResponseStatus.EDialRefused)
|
||||
return
|
||||
defer:
|
||||
await dialBackConn.close()
|
||||
|
||||
# if observed address for peer is not in address list to try
|
||||
# then we perform Amplification Attack Prevention
|
||||
if not ipAddrMatches(observedIPAddr, req.addrs):
|
||||
debug "Starting amplification attack prevention",
|
||||
observedIPAddr = observedIPAddr, testAddr = req.addrs[addrIdx]
|
||||
# send DialDataRequest and wait until dataReceived is enough
|
||||
if not await self.amplificationAttackPrevention(conn, addrIdx):
|
||||
return
|
||||
|
||||
debug "Sending DialBack",
|
||||
nonce = req.nonce, addrIdx = addrIdx, addr = req.addrs[addrIdx]
|
||||
|
||||
let dialStatus = await dialBackConn.dialBack(req.nonce)
|
||||
|
||||
await conn.sendDialResponse(
|
||||
ResponseStatus.Ok, addrIdx = Opt.some(addrIdx), dialStatus = Opt.some(dialStatus)
|
||||
)
|
||||
|
||||
proc new*(
|
||||
T: typedesc[AutonatV2],
|
||||
switch: Switch,
|
||||
config: AutonatV2Config = AutonatV2Config.new(),
|
||||
): T =
|
||||
let autonatV2 = T(switch: switch, config: config)
|
||||
proc handleStream(
|
||||
conn: Connection, proto: string
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
defer:
|
||||
await conn.close()
|
||||
|
||||
let msg =
|
||||
try:
|
||||
AutonatV2Msg.decode(initProtoBuffer(await conn.readLp(AutonatV2MsgLpSize))).valueOr:
|
||||
trace "Unable to decode AutonatV2Msg"
|
||||
return
|
||||
except LPStreamError as exc:
|
||||
debug "Could not receive AutonatV2Msg", description = exc.msg
|
||||
return
|
||||
|
||||
debug "Received message", msgType = $msg.msgType
|
||||
if msg.msgType != MsgType.DialRequest:
|
||||
debug "Expecting DialRequest", receivedMsgType = msg.msgType
|
||||
return
|
||||
|
||||
try:
|
||||
await autonatV2.handleDialRequest(conn, msg.dialReq)
|
||||
except CancelledError as exc:
|
||||
raise exc
|
||||
except LPStreamRemoteClosedError as exc:
|
||||
debug "Connection closed by peer", description = exc.msg, peer = conn.peerId
|
||||
except LPStreamError as exc:
|
||||
debug "Stream Error", description = exc.msg
|
||||
except DialFailedError as exc:
|
||||
debug "Could not dial peer", description = exc.msg, peer = conn.peerId
|
||||
|
||||
autonatV2.handler = handleStream
|
||||
autonatV2.codec = $AutonatV2Codec.DialRequest
|
||||
autonatV2
|
||||
262
libp2p/protocols/connectivity/autonatv2/types.nim
Normal file
@@ -0,0 +1,262 @@
|
||||
# Nim-LibP2P
|
||||
# Copyright (c) 2025 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
{.push raises: [].}
|
||||
|
||||
import results, chronos, chronicles
|
||||
import
|
||||
../../../multiaddress, ../../../peerid, ../../../protobuf/minprotobuf, ../../../switch
|
||||
from ../autonat/types import NetworkReachability
|
||||
|
||||
const
|
||||
DefaultDialTimeout*: Duration = 15.seconds
|
||||
DefaultAmplificationAttackDialTimeout*: Duration = 3.seconds
|
||||
DefaultDialDataSize*: uint64 = 50 * 1024 # 50 KiB (slightly more than 50 KB)
|
||||
AutonatV2MsgLpSize*: int = 1024
|
||||
# readLp must accept more than 4096 bytes (the data carried by one DialDataResponse) plus protobuf overhead
|
||||
AutonatV2DialDataResponseLpSize*: int = 5000
|
||||
|
||||
type
|
||||
AutonatV2Codec* {.pure.} = enum
|
||||
DialRequest = "/libp2p/autonat/2/dial-request"
|
||||
DialBack = "/libp2p/autonat/2/dial-back"
|
||||
|
||||
AutonatV2Response* = object
|
||||
reachability*: NetworkReachability
|
||||
dialResp*: DialResponse
|
||||
addrs*: Opt[MultiAddress]
|
||||
|
||||
AutonatV2Error* = object of LPError
|
||||
|
||||
Nonce* = uint64
|
||||
|
||||
AddrIdx* = uint32
|
||||
|
||||
NumBytes* = uint64
|
||||
|
||||
MsgType* {.pure.} = enum
|
||||
# DialBack and DialBackResponse are not defined as AutonatV2Msg as per the spec
|
||||
# likely because they are expected in response to some other message
|
||||
DialRequest
|
||||
DialResponse
|
||||
DialDataRequest
|
||||
DialDataResponse
|
||||
|
||||
ResponseStatus* {.pure.} = enum
|
||||
EInternalError = 0
|
||||
ERequestRejected = 100
|
||||
EDialRefused = 101
|
||||
Ok = 200
|
||||
|
||||
DialBackStatus* {.pure.} = enum
|
||||
Ok = 0
|
||||
|
||||
DialStatus* {.pure.} = enum
|
||||
Unused = 0
|
||||
EDialError = 100
|
||||
EDialBackError = 101
|
||||
Ok = 200
|
||||
|
||||
DialRequest* = object
|
||||
addrs*: seq[MultiAddress]
|
||||
nonce*: Nonce
|
||||
|
||||
DialResponse* = object
|
||||
status*: ResponseStatus
|
||||
addrIdx*: Opt[AddrIdx]
|
||||
dialStatus*: Opt[DialStatus]
|
||||
|
||||
DialBack* = object
|
||||
nonce*: Nonce
|
||||
|
||||
DialBackResponse* = object
|
||||
status*: DialBackStatus
|
||||
|
||||
DialDataRequest* = object
|
||||
addrIdx*: AddrIdx
|
||||
numBytes*: NumBytes
|
||||
|
||||
DialDataResponse* = object
|
||||
data*: seq[byte]
|
||||
|
||||
AutonatV2Msg* = object
|
||||
case msgType*: MsgType
|
||||
of MsgType.DialRequest:
|
||||
dialReq*: DialRequest
|
||||
of MsgType.DialResponse:
|
||||
dialResp*: DialResponse
|
||||
of MsgType.DialDataRequest:
|
||||
dialDataReq*: DialDataRequest
|
||||
of MsgType.DialDataResponse:
|
||||
dialDataResp*: DialDataResponse
|
||||
|
||||
# DialRequest
|
||||
proc encode*(dialReq: DialRequest): ProtoBuffer =
|
||||
var encoded = initProtoBuffer()
|
||||
for ma in dialReq.addrs:
|
||||
encoded.write(1, ma.data.buffer)
|
||||
encoded.write(2, dialReq.nonce)
|
||||
encoded.finish()
|
||||
encoded
|
||||
|
||||
proc decode*(T: typedesc[DialRequest], pb: ProtoBuffer): Opt[T] =
|
||||
var
|
||||
addrs: seq[MultiAddress]
|
||||
nonce: Nonce
|
||||
if not ?pb.getRepeatedField(1, addrs).toOpt():
|
||||
return Opt.none(T)
|
||||
if not ?pb.getField(2, nonce).toOpt():
|
||||
return Opt.none(T)
|
||||
Opt.some(T(addrs: addrs, nonce: nonce))
|
||||
|
||||
# DialResponse
|
||||
proc encode*(dialResp: DialResponse): ProtoBuffer =
|
||||
var encoded = initProtoBuffer()
|
||||
encoded.write(1, dialResp.status.uint)
|
||||
# minprotobuf uses float64 for fixed64 fields
|
||||
dialResp.addrIdx.withValue(addrIdx):
|
||||
encoded.write(2, addrIdx)
|
||||
dialResp.dialStatus.withValue(dialStatus):
|
||||
encoded.write(3, dialStatus.uint)
|
||||
encoded.finish()
|
||||
encoded
|
||||
|
||||
proc decode*(T: typedesc[DialResponse], pb: ProtoBuffer): Opt[T] =
|
||||
var
|
||||
status: uint
|
||||
addrIdx: AddrIdx
|
||||
dialStatus: uint
|
||||
|
||||
if not ?pb.getField(1, status).toOpt():
|
||||
return Opt.none(T)
|
||||
|
||||
var optAddrIdx = Opt.none(AddrIdx)
|
||||
if ?pb.getField(2, addrIdx).toOpt():
|
||||
optAddrIdx = Opt.some(addrIdx)
|
||||
|
||||
var optDialStatus = Opt.none(DialStatus)
|
||||
if ?pb.getField(3, dialStatus).toOpt():
|
||||
optDialStatus = Opt.some(cast[DialStatus](dialStatus))
|
||||
|
||||
Opt.some(
|
||||
T(
|
||||
status: cast[ResponseStatus](status),
|
||||
addrIdx: optAddrIdx,
|
||||
dialStatus: optDialStatus,
|
||||
)
|
||||
)
|
||||
|
||||
# DialBack
|
||||
proc encode*(dialBack: DialBack): ProtoBuffer =
|
||||
var encoded = initProtoBuffer()
|
||||
encoded.write(1, dialBack.nonce)
|
||||
encoded.finish()
|
||||
encoded
|
||||
|
||||
proc decode*(T: typedesc[DialBack], pb: ProtoBuffer): Opt[T] =
|
||||
var nonce: Nonce
|
||||
if not ?pb.getField(1, nonce).toOpt():
|
||||
return Opt.none(T)
|
||||
Opt.some(T(nonce: nonce))
|
||||
|
||||
# DialBackResponse
|
||||
proc encode*(dialBackResp: DialBackResponse): ProtoBuffer =
|
||||
var encoded = initProtoBuffer()
|
||||
encoded.write(1, dialBackResp.status.uint)
|
||||
encoded.finish()
|
||||
encoded
|
||||
|
||||
proc decode*(T: typedesc[DialBackResponse], pb: ProtoBuffer): Opt[T] =
|
||||
var status: uint
|
||||
if not ?pb.getField(1, status).toOpt():
|
||||
return Opt.none(T)
|
||||
Opt.some(T(status: cast[DialBackStatus](status)))
|
||||
|
||||
# DialDataRequest
|
||||
proc encode*(dialDataReq: DialDataRequest): ProtoBuffer =
|
||||
var encoded = initProtoBuffer()
|
||||
encoded.write(1, dialDataReq.addrIdx)
|
||||
encoded.write(2, dialDataReq.numBytes)
|
||||
encoded.finish()
|
||||
encoded
|
||||
|
||||
proc decode*(T: typedesc[DialDataRequest], pb: ProtoBuffer): Opt[T] =
|
||||
var
|
||||
addrIdx: AddrIdx
|
||||
numBytes: NumBytes
|
||||
if not ?pb.getField(1, addrIdx).toOpt():
|
||||
return Opt.none(T)
|
||||
if not ?pb.getField(2, numBytes).toOpt():
|
||||
return Opt.none(T)
|
||||
Opt.some(T(addrIdx: addrIdx, numBytes: numBytes))
|
||||
|
||||
# DialDataResponse
|
||||
proc encode*(dialDataResp: DialDataResponse): ProtoBuffer =
|
||||
var encoded = initProtoBuffer()
|
||||
encoded.write(1, dialDataResp.data)
|
||||
encoded.finish()
|
||||
encoded
|
||||
|
||||
proc decode*(T: typedesc[DialDataResponse], pb: ProtoBuffer): Opt[T] =
|
||||
var data: seq[byte]
|
||||
if not ?pb.getField(1, data).toOpt():
|
||||
return Opt.none(T)
|
||||
Opt.some(T(data: data))
|
||||
|
||||
proc protoField(msgType: MsgType): int =
|
||||
case msgType
|
||||
of MsgType.DialRequest: 1.int
|
||||
of MsgType.DialResponse: 2.int
|
||||
of MsgType.DialDataRequest: 3.int
|
||||
of MsgType.DialDataResponse: 4.int
|
||||
|
||||
# AutonatV2Msg
|
||||
proc encode*(msg: AutonatV2Msg): ProtoBuffer =
|
||||
var encoded = initProtoBuffer()
|
||||
case msg.msgType
|
||||
of MsgType.DialRequest:
|
||||
encoded.write(MsgType.DialRequest.protoField, msg.dialReq.encode())
|
||||
of MsgType.DialResponse:
|
||||
encoded.write(MsgType.DialResponse.protoField, msg.dialResp.encode())
|
||||
of MsgType.DialDataRequest:
|
||||
encoded.write(MsgType.DialDataRequest.protoField, msg.dialDataReq.encode())
|
||||
of MsgType.DialDataResponse:
|
||||
encoded.write(MsgType.DialDataResponse.protoField, msg.dialDataResp.encode())
|
||||
encoded.finish()
|
||||
encoded
|
||||
|
||||
proc decode*(T: typedesc[AutonatV2Msg], pb: ProtoBuffer): Opt[T] =
|
||||
var
|
||||
msgTypeOrd: uint32
|
||||
msg: ProtoBuffer
|
||||
|
||||
if ?pb.getField(MsgType.DialRequest.protoField, msg).toOpt():
|
||||
let dialReq = DialRequest.decode(msg).valueOr:
|
||||
return Opt.none(AutonatV2Msg)
|
||||
Opt.some(AutonatV2Msg(msgType: MsgType.DialRequest, dialReq: dialReq))
|
||||
elif ?pb.getField(MsgType.DialResponse.protoField, msg).toOpt():
|
||||
let dialResp = DialResponse.decode(msg).valueOr:
|
||||
return Opt.none(AutonatV2Msg)
|
||||
Opt.some(AutonatV2Msg(msgType: MsgType.DialResponse, dialResp: dialResp))
|
||||
elif ?pb.getField(MsgType.DialDataRequest.protoField, msg).toOpt():
|
||||
let dialDataReq = DialDataRequest.decode(msg).valueOr:
|
||||
return Opt.none(AutonatV2Msg)
|
||||
Opt.some(AutonatV2Msg(msgType: MsgType.DialDataRequest, dialDataReq: dialDataReq))
|
||||
elif ?pb.getField(MsgType.DialDataResponse.protoField, msg).toOpt():
|
||||
let dialDataResp = DialDataResponse.decode(msg).valueOr:
|
||||
return Opt.none(AutonatV2Msg)
|
||||
Opt.some(
|
||||
AutonatV2Msg(msgType: MsgType.DialDataResponse, dialDataResp: dialDataResp)
|
||||
)
|
||||
else:
|
||||
Opt.none(AutonatV2Msg)
|
||||
|
||||
# Custom `==` is needed to compare since AutonatV2Msg is a case object
|
||||
proc `==`*(a, b: AutonatV2Msg): bool =
|
||||
a.msgType == b.msgType and a.encode() == b.encode()
|
||||
47
libp2p/protocols/connectivity/autonatv2/utils.nim
Normal file
@@ -0,0 +1,47 @@
|
||||
{.push raises: [].}
|
||||
|
||||
import results
|
||||
import chronos
|
||||
import
|
||||
../../protocol,
|
||||
../../../switch,
|
||||
../../../multiaddress,
|
||||
../../../multicodec,
|
||||
../../../peerid,
|
||||
../../../protobuf/minprotobuf,
|
||||
../autonat/service,
|
||||
./types
|
||||
|
||||
proc asNetworkReachability*(self: DialResponse): NetworkReachability =
|
||||
if self.status == EInternalError:
|
||||
return Unknown
|
||||
if self.status == ERequestRejected:
|
||||
return Unknown
|
||||
if self.status == EDialRefused:
|
||||
return Unknown
|
||||
|
||||
# if got here it means a dial was attempted
|
||||
let dialStatus = self.dialStatus.valueOr:
|
||||
return Unknown
|
||||
if dialStatus == Unused:
|
||||
return Unknown
|
||||
if dialStatus == EDialError:
|
||||
return NotReachable
|
||||
if dialStatus == EDialBackError:
|
||||
return NotReachable
|
||||
return Reachable
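A small sketch of the mapping for two fabricated responses; the field values are illustrative only:

# Hypothetical: a completed dial maps to Reachable, a refusal stays Unknown.
let okResp = DialResponse(
  status: ResponseStatus.Ok,
  addrIdx: Opt.some(AddrIdx(0)),
  dialStatus: Opt.some(DialStatus.Ok),
)
assert okResp.asNetworkReachability() == Reachable
assert DialResponse(status: ResponseStatus.EDialRefused).asNetworkReachability() == Unknown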
|
||||
|
||||
proc asAutonatV2Response*(
|
||||
self: DialResponse, testAddrs: seq[MultiAddress]
|
||||
): AutonatV2Response =
|
||||
let addrIdx = self.addrIdx.valueOr:
|
||||
return AutonatV2Response(
|
||||
reachability: self.asNetworkReachability(),
|
||||
dialResp: self,
|
||||
addrs: Opt.none(MultiAddress),
|
||||
)
|
||||
AutonatV2Response(
|
||||
reachability: self.asNetworkReachability(),
|
||||
dialResp: self,
|
||||
addrs: Opt.some(testAddrs[addrIdx]),
|
||||
)
|
||||
@@ -422,6 +422,6 @@ method stop*(r: Relay): Future[void] {.async: (raises: [], raw: true).} =
|
||||
warn "Stopping relay without starting it"
|
||||
return fut
|
||||
r.started = false
|
||||
r.reservationLoop.cancel()
|
||||
r.reservationLoop.cancelSoon()
|
||||
r.reservationLoop = nil
|
||||
fut
|
||||
|
||||
@@ -31,7 +31,7 @@ type RelayTransport* = ref object of Transport
|
||||
|
||||
method start*(
|
||||
self: RelayTransport, ma: seq[MultiAddress]
|
||||
) {.async: (raises: [LPError, transport.TransportError]).} =
|
||||
) {.async: (raises: [LPError, transport.TransportError, CancelledError]).} =
|
||||
if self.selfRunning:
|
||||
trace "Relay transport already running"
|
||||
return
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import chronos
|
||||
|
||||
const
|
||||
IdLength* = 32 # 256-bit IDs
|
||||
k* = 20 # replication parameter
|
||||
DefaultReplic* = 20 ## replication parameter, aka `k` in the spec
|
||||
alpha* = 10 # concurrency parameter
|
||||
ttl* = 24.hours
|
||||
maxBuckets* = 256
|
||||
|
||||
const KadCodec* = "/ipfs/kad/1.0.0"
|
||||
|
||||
@@ -1,53 +1,451 @@
|
||||
import chronos
|
||||
import chronicles
|
||||
import sequtils
|
||||
import sets
|
||||
import ../../peerid
|
||||
import ./consts
|
||||
import ./xordistance
|
||||
import ./routingtable
|
||||
import ./lookupstate
|
||||
import ./requests
|
||||
import ./keys
|
||||
import ../protocol
|
||||
import ../../switch
|
||||
import ./protobuf
|
||||
import ../../switch
|
||||
import ../../multihash
|
||||
import ../../utils/heartbeat
|
||||
import std/[times, options, tables]
|
||||
import results
|
||||
|
||||
logScope:
|
||||
topics = "kad-dht"
|
||||
|
||||
type EntryKey* = object
|
||||
data: seq[byte]
|
||||
|
||||
proc init*(T: typedesc[EntryKey], inner: seq[byte]): EntryKey {.gcsafe, raises: [].} =
|
||||
EntryKey(data: inner)
|
||||
|
||||
type EntryValue* = object
|
||||
data*: seq[byte] # public because needed for tests
|
||||
|
||||
proc init*(
|
||||
T: typedesc[EntryValue], inner: seq[byte]
|
||||
): EntryValue {.gcsafe, raises: [].} =
|
||||
EntryValue(data: inner)
|
||||
|
||||
type TimeStamp* = object
|
||||
# Currently a string, because for some reason, that's what is chosen at the protobuf level
|
||||
# TODO: convert between RFC3339 strings and use of integers (i.e. the _correct_ way)
|
||||
ts*: string # only public because needed for tests
|
||||
|
||||
type EntryRecord* = object
|
||||
value*: EntryValue # only public because needed for tests
|
||||
time*: TimeStamp # only public because needed for tests
|
||||
|
||||
proc init*(
|
||||
T: typedesc[EntryRecord], value: EntryValue, time: Option[TimeStamp]
|
||||
): EntryRecord {.gcsafe, raises: [].} =
|
||||
EntryRecord(value: value, time: time.get(TimeStamp(ts: $times.now().utc)))
|
||||
|
||||
type LocalTable* = object
|
||||
entries*: Table[EntryKey, EntryRecord] # public because needed for tests
|
||||
|
||||
proc init(self: typedesc[LocalTable]): LocalTable {.raises: [].} =
|
||||
LocalTable()
|
||||
|
||||
type EntryCandidate* = object
|
||||
key*: EntryKey
|
||||
value*: EntryValue
|
||||
|
||||
type ValidatedEntry* = object
|
||||
key: EntryKey
|
||||
value: EntryValue
|
||||
|
||||
proc init*(
|
||||
T: typedesc[ValidatedEntry], key: EntryKey, value: EntryValue
|
||||
): ValidatedEntry {.gcsafe, raises: [].} =
|
||||
ValidatedEntry(key: key, value: value)
|
||||
|
||||
type EntryValidator* = ref object of RootObj
|
||||
method isValid*(
|
||||
self: EntryValidator, key: EntryKey, val: EntryValue
|
||||
): bool {.base, raises: [], gcsafe.} =
|
||||
doAssert(false, "unimplemented base method")
|
||||
|
||||
type EntrySelector* = ref object of RootObj
|
||||
method select*(
|
||||
self: EntrySelector, cand: EntryRecord, others: seq[EntryRecord]
|
||||
): Result[EntryRecord, string] {.base, raises: [], gcsafe.} =
|
||||
doAssert(false, "EntrySelection base not implemented")
|
||||
|
||||
type KadDHT* = ref object of LPProtocol
|
||||
switch: Switch
|
||||
rng: ref HmacDrbgContext
|
||||
rtable*: RoutingTable
|
||||
maintenanceLoop: Future[void]
|
||||
dataTable*: LocalTable
|
||||
entryValidator: EntryValidator
|
||||
entrySelector: EntrySelector
|
||||
|
||||
proc insert*(
|
||||
self: var LocalTable, value: sink ValidatedEntry, time: TimeStamp
|
||||
) {.raises: [].} =
|
||||
debug "local table insertion", key = value.key.data, value = value.value.data
|
||||
self.entries[value.key] = EntryRecord(value: value.value, time: time)
|
||||
|
||||
const MaxMsgSize = 4096
|
||||
# Forward declaration
|
||||
proc findNode*(
|
||||
kad: KadDHT, targetId: Key
|
||||
): Future[seq[PeerId]] {.async: (raises: [CancelledError]).}
|
||||
|
||||
proc sendFindNode(
|
||||
kad: KadDHT, peerId: PeerId, addrs: seq[MultiAddress], targetId: Key
|
||||
): Future[Message] {.
|
||||
async: (raises: [CancelledError, DialFailedError, ValueError, LPStreamError])
|
||||
.} =
|
||||
let conn =
|
||||
if addrs.len == 0:
|
||||
await kad.switch.dial(peerId, KadCodec)
|
||||
else:
|
||||
await kad.switch.dial(peerId, addrs, KadCodec)
|
||||
defer:
|
||||
await conn.close()
|
||||
|
||||
let msg = Message(msgType: MessageType.findNode, key: some(targetId.getBytes()))
|
||||
await conn.writeLp(msg.encode().buffer)
|
||||
|
||||
let reply = Message.decode(await conn.readLp(MaxMsgSize)).tryGet()
|
||||
if reply.msgType != MessageType.findNode:
|
||||
raise newException(ValueError, "unexpected message type in reply: " & $reply)
|
||||
|
||||
return reply
|
||||
|
||||
proc waitRepliesOrTimeouts(
|
||||
pendingFutures: Table[PeerId, Future[Message]]
|
||||
): Future[(seq[Message], seq[PeerId])] {.async: (raises: [CancelledError]).} =
|
||||
await allFutures(toSeq(pendingFutures.values))
|
||||
|
||||
var receivedReplies: seq[Message] = @[]
|
||||
var failedPeers: seq[PeerId] = @[]
|
||||
|
||||
for (peerId, replyFut) in pendingFutures.pairs:
|
||||
try:
|
||||
receivedReplies.add(await replyFut)
|
||||
except CatchableError:
|
||||
failedPeers.add(peerId)
|
||||
error "could not send find_node to peer", peerId, err = getCurrentExceptionMsg()
|
||||
|
||||
return (receivedReplies, failedPeers)
|
||||
|
||||
proc dispatchPutVal(
|
||||
kad: KadDHT, peer: PeerId, entry: ValidatedEntry
|
||||
): Future[void] {.async: (raises: [CancelledError, DialFailedError, LPStreamError]).} =
|
||||
let conn = await kad.switch.dial(peer, KadCodec)
|
||||
defer:
|
||||
await conn.close()
|
||||
let msg = Message(
|
||||
msgType: MessageType.putValue,
|
||||
record: some(Record(key: some(entry.key.data), value: some(entry.value.data))),
|
||||
)
|
||||
await conn.writeLp(msg.encode().buffer)
|
||||
|
||||
let reply = Message.decode(await conn.readLp(MaxMsgSize)).valueOr:
|
||||
# todo log this more meaningfully
|
||||
error "putValue reply decode fail", error = error, conn = conn
|
||||
return
|
||||
if reply != msg:
|
||||
error "unexpected change between msg and reply: ",
|
||||
msg = msg, reply = reply, conn = conn
|
||||
|
||||
proc putValue*(
|
||||
kad: KadDHT, entKey: EntryKey, value: EntryValue, timeout: Option[int]
|
||||
): Future[Result[void, string]] {.async: (raises: [CancelledError]), gcsafe.} =
|
||||
if not kad.entryValidator.isValid(entKey, value):
|
||||
return err("invalid key/value pair")
|
||||
|
||||
let others: seq[EntryRecord] =
|
||||
if entKey in kad.dataTable.entries:
|
||||
@[kad.dataTable.entries.getOrDefault(entKey)]
|
||||
else:
|
||||
@[]
|
||||
|
||||
let candAsRec = EntryRecord.init(value, none(TimeStamp))
|
||||
let confirmedRec = kad.entrySelector.select(candAsRec, others).valueOr:
|
||||
error "application provided selector error (local)", msg = error
|
||||
return err(error)
|
||||
trace "local putval", candidate = candAsRec, others = others, selected = confirmedRec
|
||||
let validEnt = ValidatedEntry.init(entKey, confirmedRec.value)
|
||||
|
||||
let peers = await kad.findNode(entKey.data.toKey())
|
||||
# We first prime the sends so the data is ready to go
|
||||
let rpcBatch = peers.mapIt(kad.dispatchPutVal(it, validEnt))
|
||||
# then we do the `move`, as insert takes the data as `sink`
|
||||
kad.dataTable.insert(validEnt, confirmedRec.time)
|
||||
try:
|
||||
# now that all the data is where it needs to be in memory, we can dispatch the
|
||||
# RPCs
|
||||
await rpcBatch.allFutures().wait(chronos.seconds(timeout.get(5)))
|
||||
|
||||
# It's quite normal for the dispatch to timeout, as it would require all calls to get
|
||||
# their response. Downstream users may desire some sort of functionality in the
|
||||
# future to get rpc telemetry, but in the meantime, we just move on...
|
||||
except AsyncTimeoutError:
|
||||
discard
|
||||
return results.ok()
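A hedged usage sketch; the key and value bytes are placeholders and kad is an already started KadDHT:

# Hypothetical publish: validation and record selection happen inside putValue.
let res = await kad.putValue(
  EntryKey.init(@[1'u8, 2, 3]), EntryValue.init(@[42'u8]), some(5)
)
if res.isErr:
  error "putValue rejected", err = res.error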
|
||||
|
||||
# Helper function forward declaration
|
||||
proc checkConvergence(state: LookupState, me: PeerId): bool {.raises: [], gcsafe.}
|
||||
|
||||
proc findNode*(
|
||||
kad: KadDHT, targetId: Key
|
||||
): Future[seq[PeerId]] {.async: (raises: [CancelledError]).} =
|
||||
## Node lookup. Iteratively search for the k closest peers to a target ID.
|
||||
## Not necessarily will return the target itself
|
||||
|
||||
#debug "findNode", target = target
|
||||
|
||||
var initialPeers = kad.rtable.findClosestPeers(targetId, DefaultReplic)
|
||||
var state = LookupState.init(targetId, initialPeers, kad.rtable.hasher)
|
||||
var addrTable: Table[PeerId, seq[MultiAddress]] =
|
||||
initTable[PeerId, seq[MultiAddress]]()
|
||||
|
||||
while not state.done:
|
||||
let toQuery = state.selectAlphaPeers()
|
||||
debug "queries", list = toQuery.mapIt(it.shortLog()), addrTab = addrTable
|
||||
var pendingFutures = initTable[PeerId, Future[Message]]()
|
||||
|
||||
# TODO: pending futures always empty here, no?
|
||||
for peer in toQuery.filterIt(
|
||||
kad.switch.peerInfo.peerId != it or pendingFutures.hasKey(it)
|
||||
):
|
||||
state.markPending(peer)
|
||||
|
||||
pendingFutures[peer] = kad
|
||||
.sendFindNode(peer, addrTable.getOrDefault(peer, @[]), targetId)
|
||||
.wait(chronos.seconds(5))
|
||||
|
||||
state.activeQueries.inc
|
||||
|
||||
let (successfulReplies, timedOutPeers) = await waitRepliesOrTimeouts(pendingFutures)
|
||||
|
||||
for msg in successfulReplies:
|
||||
for peer in msg.closerPeers:
|
||||
let pid = PeerId.init(peer.id)
|
||||
if not pid.isOk:
|
||||
error "PeerId init failed, this is unusual", data = peer.id
|
||||
continue
|
||||
addrTable[pid.get()] = peer.addrs
|
||||
state.updateShortlist(
|
||||
msg,
|
||||
proc(p: PeerInfo) =
|
||||
discard kad.rtable.insert(p.peerId)
|
||||
# Nodes might return different addresses for a peer, so we append instead of replacing
|
||||
var existingAddresses =
|
||||
kad.switch.peerStore[AddressBook][p.peerId].toHashSet()
|
||||
for a in p.addrs:
|
||||
existingAddresses.incl(a)
|
||||
kad.switch.peerStore[AddressBook][p.peerId] = existingAddresses.toSeq()
|
||||
# TODO: add TTL to peerstore, otherwise we can spam it with junk
|
||||
,
|
||||
kad.rtable.hasher,
|
||||
)
|
||||
|
||||
for timedOut in timedOutPeers:
|
||||
state.markFailed(timedOut)
|
||||
|
||||
# Check for convergence: no active queries, and no other peers to be selected
|
||||
state.done = checkConvergence(state, kad.switch.peerInfo.peerId)
|
||||
|
||||
return state.selectClosestK()
|
||||
|
||||
proc findPeer*(
|
||||
kad: KadDHT, peer: PeerId
|
||||
): Future[Result[PeerInfo, string]] {.async: (raises: [CancelledError]).} =
|
||||
## Walks the key space until it finds candidate addresses for a peer Id
|
||||
|
||||
if kad.switch.peerInfo.peerId == peer:
|
||||
# Looking for yourself.
|
||||
return ok(kad.switch.peerInfo)
|
||||
|
||||
if kad.switch.isConnected(peer):
|
||||
# Return known info about already connected peer
|
||||
return ok(PeerInfo(peerId: peer, addrs: kad.switch.peerStore[AddressBook][peer]))
|
||||
|
||||
let foundNodes = await kad.findNode(peer.toKey())
|
||||
if not foundNodes.contains(peer):
|
||||
return err("peer not found")
|
||||
|
||||
return ok(PeerInfo(peerId: peer, addrs: kad.switch.peerStore[AddressBook][peer]))
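A lookup sketch; somePeerId and switch are assumed to exist at the call site:

# Hypothetical: resolve a peer's addresses via the DHT walk above, then dial it.
let found = await kad.findPeer(somePeerId)
if found.isOk:
  let pinfo = found.get()
  await switch.connect(pinfo.peerId, pinfo.addrs)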
|
||||
|
||||
proc checkConvergence(state: LookupState, me: PeerId): bool {.raises: [], gcsafe.} =
|
||||
let ready = state.activeQueries == 0
|
||||
let noNew = selectAlphaPeers(state).filterIt(me != it).len == 0
|
||||
return ready and noNew
|
||||
|
||||
proc bootstrap*(
|
||||
kad: KadDHT, bootstrapNodes: seq[PeerInfo]
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
for b in bootstrapNodes:
|
||||
try:
|
||||
await kad.switch.connect(b.peerId, b.addrs)
|
||||
debug "connected to bootstrap peer", peerId = b.peerId
|
||||
except DialFailedError as e:
|
||||
# at some point will want to bubble up a Result[void, SomeErrorEnum]
|
||||
error "failed to dial to bootstrap peer", peerId = b.peerId, error = e.msg
|
||||
continue
|
||||
|
||||
let msg =
|
||||
try:
|
||||
await kad.sendFindNode(b.peerId, b.addrs, kad.rtable.selfId).wait(
|
||||
chronos.seconds(5)
|
||||
)
|
||||
except CatchableError as e:
|
||||
debug "send find node exception during bootstrap",
|
||||
target = b.peerId, addrs = b.addrs, err = e.msg
|
||||
continue
|
||||
for peer in msg.closerPeers:
|
||||
let p = PeerId.init(peer.id).valueOr:
|
||||
debug "invalid peer id received", error = error
|
||||
continue
|
||||
discard kad.rtable.insert(p)
|
||||
try:
|
||||
kad.switch.peerStore[AddressBook][p] = peer.addrs
|
||||
except:
|
||||
error "this is here because an ergonomic means of keying into a table without exceptions is unknown"
|
||||
|
||||
# bootstrap node replied successfully. Adding to routing table
|
||||
discard kad.rtable.insert(b.peerId)
|
||||
|
||||
let key = PeerId.random(kad.rng).valueOr:
|
||||
doAssert(false, "this should never happen")
|
||||
return
|
||||
discard await kad.findNode(key.toKey())
|
||||
info "bootstrap lookup complete"
|
||||
|
||||
proc refreshBuckets(kad: KadDHT) {.async: (raises: [CancelledError]).} =
|
||||
for i in 0 ..< kad.rtable.buckets.len:
|
||||
if kad.rtable.buckets[i].isStale():
|
||||
let randomKey = randomKeyInBucketRange(kad.rtable.selfId, i, kad.rng)
|
||||
discard await kad.findNode(randomKey)
|
||||
|
||||
proc maintainBuckets(kad: KadDHT) {.async: (raises: [CancelledError]).} =
|
||||
heartbeat "refresh buckets", 10.minutes:
|
||||
debug "TODO: implement bucket maintenance"
|
||||
heartbeat "refresh buckets", chronos.minutes(10):
|
||||
await kad.refreshBuckets()
|
||||
|
||||
proc new*(
|
||||
T: typedesc[KadDHT], switch: Switch, rng: ref HmacDrbgContext = newRng()
|
||||
T: typedesc[KadDHT],
|
||||
switch: Switch,
|
||||
validator: EntryValidator,
|
||||
entrySelector: EntrySelector,
|
||||
rng: ref HmacDrbgContext = newRng(),
|
||||
): T {.raises: [].} =
|
||||
var rtable = RoutingTable.init(switch.peerInfo.peerId)
|
||||
let kad = T(rng: rng, switch: switch, rtable: rtable)
|
||||
var rtable = RoutingTable.init(switch.peerInfo.peerId.toKey(), Opt.none(XorDHasher))
|
||||
let kad = T(
|
||||
rng: rng,
|
||||
switch: switch,
|
||||
rtable: rtable,
|
||||
dataTable: LocalTable.init(),
|
||||
entryValidator: validator,
|
||||
entrySelector: entrySelector,
|
||||
)
|
||||
|
||||
kad.codec = KadCodec
|
||||
kad.handler = proc(
|
||||
conn: Connection, proto: string
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
try:
|
||||
while not conn.atEof:
|
||||
let
|
||||
buf = await conn.readLp(4096)
|
||||
msg = Message.decode(buf).tryGet()
|
||||
|
||||
# TODO: handle msg.msgType
|
||||
except CancelledError as exc:
|
||||
raise exc
|
||||
except CatchableError:
|
||||
error "could not handle request",
|
||||
peerId = conn.PeerId, err = getCurrentExceptionMsg()
|
||||
finally:
|
||||
defer:
|
||||
await conn.close()
|
||||
while not conn.atEof:
|
||||
let buf =
|
||||
try:
|
||||
await conn.readLp(MaxMsgSize)
|
||||
except LPStreamError as e:
|
||||
debug "Read error when handling kademlia RPC", conn = conn, err = e.msg
|
||||
return
|
||||
let msg = Message.decode(buf).valueOr:
|
||||
debug "msg decode error handling kademlia RPC", err = error
|
||||
return
|
||||
|
||||
case msg.msgType
|
||||
of MessageType.findNode:
|
||||
let targetIdBytes = msg.key.valueOr:
|
||||
error "findNode message without key data present", msg = msg, conn = conn
|
||||
return
|
||||
let targetId = PeerId.init(targetIdBytes).valueOr:
|
||||
error "findNode message without valid key data", msg = msg, conn = conn
|
||||
return
|
||||
let closerPeers = kad.rtable
|
||||
.findClosest(targetId.toKey(), DefaultReplic)
|
||||
# exclude the requesting node, because telling a peer about itself does not reduce the distance
|
||||
.filterIt(it != conn.peerId.toKey())
|
||||
|
||||
let responsePb = encodeFindNodeReply(closerPeers, switch)
|
||||
try:
|
||||
await conn.writeLp(responsePb.buffer)
|
||||
except LPStreamError as e:
|
||||
debug "write error when writing kad find-node RPC reply",
|
||||
conn = conn, err = e.msg
|
||||
return
|
||||
|
||||
# Peer is useful. adding to rtable
|
||||
discard kad.rtable.insert(conn.peerId)
|
||||
of MessageType.putValue:
|
||||
let record = msg.record.valueOr:
|
||||
error "no record in message buffer", msg = msg, conn = conn
|
||||
return
|
||||
let (skey, svalue) =
|
||||
if record.key.isSome() and record.value.isSome():
|
||||
(record.key.unsafeGet(), record.value.unsafeGet())
|
||||
else:
|
||||
error "no key or no value in rpc buffer", msg = msg, conn = conn
|
||||
return
|
||||
let key = EntryKey.init(skey)
|
||||
let value = EntryValue.init(svalue)
|
||||
|
||||
# Value sanitisation done. Start insertion process
|
||||
if not kad.entryValidator.isValid(key, value):
|
||||
return
|
||||
|
||||
let others =
|
||||
if kad.dataTable.entries.contains(key):
|
||||
# use getOrDefault here to avoid a KeyError exception
|
||||
@[kad.dataTable.entries.getOrDefault(key)]
|
||||
else:
|
||||
@[]
|
||||
let candRec = EntryRecord.init(value, none(TimeStamp))
|
||||
let selectedRec = kad.entrySelector.select(candRec, others).valueOr:
|
||||
error "application provided selector error", msg = error, conn = conn
|
||||
return
|
||||
trace "putval handler selection",
|
||||
cand = candRec, others = others, selected = selectedRec
|
||||
|
||||
# Assume that if selection goes with another value, that it is valid
|
||||
let validated = ValidatedEntry(key: key, value: selectedRec.value)
|
||||
|
||||
kad.dataTable.insert(validated, selectedRec.time)
|
||||
# consistent with following link, echo message without change
|
||||
# https://github.com/libp2p/js-libp2p/blob/cf9aab5c841ec08bc023b9f49083c95ad78a7a07/packages/kad-dht/src/rpc/handlers/put-value.ts#L22
|
||||
try:
|
||||
await conn.writeLp(buf)
|
||||
except LPStreamError as e:
|
||||
debug "write error when writing kad put-value RPC reply",
|
||||
conn = conn, err = e.msg
|
||||
return
|
||||
else:
|
||||
error "unhandled kad-dht message type", msg = msg
|
||||
return
|
||||
return kad
|
||||
|
||||
proc setSelector*(kad: KadDHT, selector: EntrySelector) =
|
||||
doAssert(selector != nil)
|
||||
kad.entrySelector = selector
|
||||
|
||||
proc setValidator*(kad: KadDHT, validator: EntryValidator) =
|
||||
doAssert(validator != nil)
|
||||
kad.entryValidator = validator
|
||||
|
||||
method start*(
|
||||
kad: KadDHT
|
||||
): Future[void] {.async: (raises: [CancelledError], raw: true).} =
|
||||
@@ -65,10 +463,12 @@ method start*(
|
||||
fut
|
||||
|
||||
method stop*(kad: KadDHT): Future[void] {.async: (raises: [], raw: true).} =
|
||||
let fut = newFuture[void]()
|
||||
fut.complete()
|
||||
if not kad.started:
|
||||
return
|
||||
return fut
|
||||
|
||||
kad.started = false
|
||||
kad.maintenanceLoop.cancelSoon()
|
||||
kad.maintenanceLoop = nil
|
||||
return
|
||||
return fut
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import nimcrypto/sha2
|
||||
import ../../peerid
|
||||
import ./consts
|
||||
import chronicles
|
||||
import stew/byteutils
|
||||
|
||||
type
|
||||
KeyType* {.pure.} = enum
|
||||
Unhashed
|
||||
Raw
|
||||
PeerId
|
||||
|
||||
@@ -13,25 +12,26 @@ type
|
||||
case kind*: KeyType
|
||||
of KeyType.PeerId:
|
||||
peerId*: PeerId
|
||||
of KeyType.Raw, KeyType.Unhashed:
|
||||
data*: array[IdLength, byte]
|
||||
of KeyType.Raw:
|
||||
data*: seq[byte]
|
||||
|
||||
proc toKey*(s: seq[byte]): Key =
|
||||
doAssert s.len == IdLength
|
||||
var data: array[IdLength, byte]
|
||||
for i in 0 ..< IdLength:
|
||||
data[i] = s[i]
|
||||
return Key(kind: KeyType.Raw, data: data)
|
||||
return Key(kind: KeyType.Raw, data: s)
|
||||
|
||||
proc toKey*(p: PeerId): Key =
|
||||
return Key(kind: KeyType.PeerId, peerId: p)
|
||||
|
||||
proc toPeerId*(k: Key): PeerId {.raises: [ValueError].} =
|
||||
if k.kind != KeyType.PeerId:
|
||||
raise newException(ValueError, "not a peerId")
|
||||
k.peerId
|
||||
|
||||
proc getBytes*(k: Key): seq[byte] =
|
||||
return
|
||||
case k.kind
|
||||
of KeyType.PeerId:
|
||||
k.peerId.getBytes()
|
||||
of KeyType.Raw, KeyType.Unhashed:
|
||||
of KeyType.Raw:
|
||||
@(k.data)
|
||||
|
||||
template `==`*(a, b: Key): bool =
|
||||
@@ -41,7 +41,7 @@ proc shortLog*(k: Key): string =
|
||||
case k.kind
|
||||
of KeyType.PeerId:
|
||||
"PeerId:" & $k.peerId
|
||||
of KeyType.Raw, KeyType.Unhashed:
|
||||
of KeyType.Raw:
|
||||
$k.kind & ":" & toHex(k.data)
|
||||
|
||||
chronicles.formatIt(Key):
|
||||
|
||||
120
libp2p/protocols/kademlia/lookupstate.nim
Normal file
@@ -0,0 +1,120 @@
|
||||
import sequtils
|
||||
import ./consts
|
||||
import ./protobuf
|
||||
import ./xordistance
|
||||
import ./keys
|
||||
import ../../[peerid, peerinfo]
|
||||
import algorithm
|
||||
import chronicles
|
||||
|
||||
type
|
||||
LookupNode* = object
|
||||
peerId: PeerId
|
||||
distance: XorDistance
|
||||
queried: bool # have we already queried this node?
|
||||
pending: bool # is there an active request right now?
|
||||
failed: bool # did the query timeout or error?
|
||||
|
||||
LookupState* = object
|
||||
targetId: Key
|
||||
shortlist: seq[LookupNode] # currently known closest nodes
|
||||
activeQueries*: int # how many queries in flight
|
||||
alpha: int # parallelism level
|
||||
repliCount: int ## aka `k` in the spec: number of closest nodes to find
|
||||
done*: bool # has lookup converged
|
||||
|
||||
proc alreadyInShortlist(state: LookupState, peer: Peer): bool =
|
||||
return state.shortlist.anyIt(it.peerId.getBytes() == peer.id)
|
||||
|
||||
proc updateShortlist*(
|
||||
state: var LookupState,
|
||||
msg: Message,
|
||||
onInsert: proc(p: PeerInfo) {.gcsafe.},
|
||||
hasher: Opt[XorDHasher],
|
||||
) =
|
||||
for newPeer in msg.closerPeers.filterIt(not alreadyInShortlist(state, it)):
|
||||
let peerInfo = PeerInfo(peerId: PeerId.init(newPeer.id).get(), addrs: newPeer.addrs)
|
||||
try:
|
||||
onInsert(peerInfo)
|
||||
state.shortlist.add(
|
||||
LookupNode(
|
||||
peerId: peerInfo.peerId,
|
||||
distance: xorDistance(peerInfo.peerId, state.targetId, hasher),
|
||||
queried: false,
|
||||
pending: false,
|
||||
failed: false,
|
||||
)
|
||||
)
|
||||
except Exception as exc:
|
||||
debug "could not update shortlist", err = exc.msg
|
||||
|
||||
state.shortlist.sort(
|
||||
proc(a, b: LookupNode): int =
|
||||
cmp(a.distance, b.distance)
|
||||
)
|
||||
|
||||
state.activeQueries.dec
|
||||
|
||||
proc markFailed*(state: var LookupState, peerId: PeerId) =
|
||||
for p in mitems(state.shortlist):
|
||||
if p.peerId == peerId:
|
||||
p.failed = true
|
||||
p.pending = false
|
||||
p.queried = true
|
||||
state.activeQueries.dec
|
||||
break
|
||||
|
||||
proc markPending*(state: var LookupState, peerId: PeerId) =
|
||||
for p in mitems(state.shortlist):
|
||||
if p.peerId == peerId:
|
||||
p.pending = true
|
||||
p.queried = true
|
||||
break
|
||||
|
||||
proc selectAlphaPeers*(state: LookupState): seq[PeerId] =
|
||||
var selected: seq[PeerId] = @[]
|
||||
for p in state.shortlist:
|
||||
if not p.queried and not p.failed and not p.pending:
|
||||
selected.add(p.peerId)
|
||||
if selected.len >= state.alpha:
|
||||
break
|
||||
return selected
|
||||
|
||||
proc init*(
|
||||
T: type LookupState,
|
||||
targetId: Key,
|
||||
initialPeers: seq[PeerId],
|
||||
hasher: Opt[XorDHasher],
|
||||
): T =
|
||||
var res = LookupState(
|
||||
targetId: targetId,
|
||||
shortlist: @[],
|
||||
activeQueries: 0,
|
||||
alpha: alpha,
|
||||
repliCount: DefaultReplic,
|
||||
done: false,
|
||||
)
|
||||
for p in initialPeers:
|
||||
res.shortlist.add(
|
||||
LookupNode(
|
||||
peerId: p,
|
||||
distance: xorDistance(p, targetId, hasher),
|
||||
queried: false,
|
||||
pending: false,
|
||||
failed: false,
|
||||
)
|
||||
)
|
||||
|
||||
res.shortlist.sort(
|
||||
proc(a, b: LookupNode): int =
|
||||
cmp(a.distance, b.distance)
|
||||
)
|
||||
return res
|
||||
|
||||
proc selectClosestK*(state: LookupState): seq[PeerId] =
|
||||
var res: seq[PeerId] = @[]
|
||||
for p in state.shortlist.filterIt(not it.failed):
|
||||
res.add(p.peerId)
|
||||
if res.len >= state.repliCount:
|
||||
break
|
||||
return res
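
For orientation, a minimal standalone sketch (simplified types and hypothetical names, not this module's API) of the shortlist discipline the procs above implement: keep candidates sorted by distance and hand out at most alpha unqueried, non-pending, non-failed peers per round.

import std/algorithm

type SketchNode = object
  id: int                          # stand-in for a PeerId
  distance: int                    # stand-in for an XorDistance
  queried, pending, failed: bool

proc selectAlpha(shortlist: seq[SketchNode], alpha: int): seq[int] =
  # pick up to `alpha` candidates that were never queried and are not in flight
  for n in shortlist:
    if not n.queried and not n.pending and not n.failed:
      result.add(n.id)
      if result.len >= alpha:
        break

when isMainModule:
  var shortlist = @[
    SketchNode(id: 1, distance: 5),
    SketchNode(id: 2, distance: 2),
    SketchNode(id: 3, distance: 9)
  ]
  shortlist.sort(proc(a, b: SketchNode): int = cmp(a.distance, b.distance))
  doAssert selectAlpha(shortlist, 2) == @[2, 1] # closest unqueried peers first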
|
||||
@@ -39,9 +39,11 @@ type
|
||||
closerPeers*: seq[Peer]
|
||||
providerPeers*: seq[Peer]
|
||||
|
||||
proc write*(pb: var ProtoBuffer, field: int, value: Record) {.raises: [].}
|
||||
proc write*(pb: var ProtoBuffer, field: int, value: Record) {.raises: [], gcsafe.}
|
||||
|
||||
proc writeOpt*[T](pb: var ProtoBuffer, field: int, opt: Option[T]) {.raises: [].}
|
||||
proc writeOpt*[T](
|
||||
pb: var ProtoBuffer, field: int, opt: Option[T]
|
||||
) {.raises: [], gcsafe.}
|
||||
|
||||
proc encode*(record: Record): ProtoBuffer {.raises: [].} =
|
||||
var pb = initProtoBuffer()
|
||||
@@ -60,7 +62,7 @@ proc encode*(peer: Peer): ProtoBuffer {.raises: [].} =
|
||||
pb.finish()
|
||||
return pb
|
||||
|
||||
proc encode*(msg: Message): ProtoBuffer {.raises: [].} =
|
||||
proc encode*(msg: Message): ProtoBuffer {.raises: [], gcsafe.} =
|
||||
var pb = initProtoBuffer()
|
||||
|
||||
pb.write(1, uint32(ord(msg.msgType)))
|
||||
@@ -80,11 +82,13 @@ proc encode*(msg: Message): ProtoBuffer {.raises: [].} =
|
||||
|
||||
return pb
|
||||
|
||||
proc writeOpt*[T](pb: var ProtoBuffer, field: int, opt: Option[T]) {.raises: [].} =
|
||||
proc writeOpt*[T](
|
||||
pb: var ProtoBuffer, field: int, opt: Option[T]
|
||||
) {.raises: [], gcsafe.} =
|
||||
opt.withValue(v):
|
||||
pb.write(field, v)
|
||||
|
||||
proc write*(pb: var ProtoBuffer, field: int, value: Record) {.raises: [].} =
|
||||
proc write*(pb: var ProtoBuffer, field: int, value: Record) {.raises: [], gcsafe.} =
|
||||
pb.write(field, value.encode())
|
||||
|
||||
proc getOptionField[T: ProtoScalar | string | seq[byte]](
|
||||
@@ -120,7 +124,7 @@ proc decode*(T: type Peer, pb: ProtoBuffer): ProtoResult[Option[T]] =
|
||||
|
||||
return ok(some(p))
|
||||
|
||||
proc decode*(T: type Message, buf: seq[byte]): ProtoResult[Option[T]] =
|
||||
proc decode*(T: type Message, buf: seq[byte]): ProtoResult[T] =
|
||||
var
|
||||
m: Message
|
||||
key: seq[byte]
|
||||
@@ -156,4 +160,4 @@ proc decode*(T: type Message, buf: seq[byte]): ProtoResult[Option[T]] =
|
||||
peer.withValue(peer):
|
||||
m.providerPeers.add(peer)
|
||||
|
||||
return ok(some(m))
|
||||
return ok(m)
|
||||
|
||||
34
libp2p/protocols/kademlia/requests.nim
Normal file
@@ -0,0 +1,34 @@
|
||||
import ../../peerid
|
||||
import ../../switch
|
||||
import ../../peerstore
|
||||
import ./protobuf
|
||||
import ../../protobuf/minprotobuf
|
||||
import ./keys
|
||||
|
||||
proc encodeFindNodeReply*(
|
||||
closerPeers: seq[Key], switch: Switch
|
||||
): ProtoBuffer {.raises: [].} =
|
||||
var msg: Message
|
||||
msg.msgType = MessageType.findNode
|
||||
for peer in closerPeers:
|
||||
let peer =
|
||||
try:
|
||||
peer.toPeerId()
|
||||
except ValueError:
|
||||
continue
|
||||
let addrs = switch.peerStore[AddressBook][peer]
|
||||
if addrs.len == 0:
|
||||
continue
|
||||
|
||||
let p = Peer(
|
||||
id: peer.getBytes(),
|
||||
addrs: addrs,
|
||||
connection:
|
||||
# TODO: this should likely be optional as it can reveal the network graph of a node
|
||||
if switch.isConnected(peer):
|
||||
ConnectionType.connected
|
||||
else:
|
||||
ConnectionType.notConnected,
|
||||
)
|
||||
msg.closerPeers.add(p)
|
||||
return msg.encode()
|
||||
@@ -8,6 +8,7 @@ import ./xordistance
|
||||
import ../../peerid
|
||||
import sequtils
|
||||
import ../../utils/sequninit
|
||||
import results
|
||||
|
||||
logScope:
|
||||
topics = "kad-dht rtable"
|
||||
@@ -23,12 +24,16 @@ type
|
||||
RoutingTable* = ref object
|
||||
selfId*: Key
|
||||
buckets*: seq[Bucket]
|
||||
hasher*: Opt[XorDHasher]
|
||||
|
||||
proc init*(T: typedesc[RoutingTable], selfId: Key): T =
|
||||
return RoutingTable(selfId: selfId, buckets: @[])
|
||||
proc `$`*(rt: RoutingTable): string =
|
||||
"selfId(" & $rt.selfId & ") buckets(" & $rt.buckets & ")"
|
||||
|
||||
proc bucketIndex*(selfId, key: Key): int =
|
||||
return xorDistance(selfId, key).leadingZeros
|
||||
proc init*(T: typedesc[RoutingTable], selfId: Key, hasher: Opt[XorDHasher]): T =
|
||||
return RoutingTable(selfId: selfId, buckets: @[], hasher: hasher)
|
||||
|
||||
proc bucketIndex*(selfId, key: Key, hasher: Opt[XorDHasher]): int =
|
||||
return xorDistance(selfId, key, hasher).leadingZeros
|
||||
|
||||
proc peerIndexInBucket(bucket: var Bucket, nodeId: Key): Opt[int] =
|
||||
for i, p in bucket.peers:
|
||||
@@ -40,7 +45,7 @@ proc insert*(rtable: var RoutingTable, nodeId: Key): bool =
|
||||
if nodeId == rtable.selfId:
|
||||
return false # No self insertion
|
||||
|
||||
let idx = bucketIndex(rtable.selfId, nodeId)
|
||||
let idx = bucketIndex(rtable.selfId, nodeId, rtable.hasher)
|
||||
if idx >= maxBuckets:
|
||||
trace "cannot insert node. max buckets have been reached",
|
||||
nodeId, bucketIdx = idx, maxBuckets
|
||||
@@ -54,12 +59,12 @@ proc insert*(rtable: var RoutingTable, nodeId: Key): bool =
|
||||
let keyx = peerIndexInBucket(bucket, nodeId)
|
||||
if keyx.isSome:
|
||||
bucket.peers[keyx.unsafeValue].lastSeen = Moment.now()
|
||||
elif bucket.peers.len < k:
|
||||
elif bucket.peers.len < DefaultReplic:
|
||||
bucket.peers.add(NodeEntry(nodeId: nodeId, lastSeen: Moment.now()))
|
||||
else:
|
||||
# TODO: eviction policy goes here, rn we drop the node
|
||||
trace "cannot insert node in bucket, dropping node",
|
||||
nodeId, bucket = k, bucketIdx = idx
|
||||
nodeId, bucket = DefaultReplic, bucketIdx = idx
|
||||
return false
|
||||
|
||||
rtable.buckets[idx] = bucket
|
||||
@@ -77,7 +82,9 @@ proc findClosest*(rtable: RoutingTable, targetId: Key, count: int): seq[Key] =
|
||||
|
||||
allNodes.sort(
|
||||
proc(a, b: Key): int =
|
||||
cmp(xorDistance(a, targetId), xorDistance(b, targetId))
|
||||
cmp(
|
||||
xorDistance(a, targetId, rtable.hasher), xorDistance(b, targetId, rtable.hasher)
|
||||
)
|
||||
)
|
||||
|
||||
return allNodes[0 ..< min(count, allNodes.len)]
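
A toy standalone illustration of the same selection idea, assuming single-byte IDs instead of hashed keys (hypothetical names, not this module's API):

import std/algorithm

proc closest(target: byte, ids: seq[byte], count: int): seq[byte] =
  # sort candidates by XOR distance to the target and keep the `count` nearest
  var sortedIds = ids
  sortedIds.sort(proc(a, b: byte): int = cmp(a xor target, b xor target))
  result = sortedIds[0 ..< min(count, sortedIds.len)]

when isMainModule:
  # distances to 10: 11 -> 1, 8 -> 2, 128 -> 138
  doAssert closest(10, @[byte 8, 128, 11], 2) == @[byte 11, 8]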
|
||||
|
||||
@@ -1,9 +1,27 @@
|
||||
import ./consts
|
||||
import stew/arrayOps
|
||||
import ./keys
|
||||
import nimcrypto/sha2
|
||||
import ../../peerid
|
||||
import results
|
||||
|
||||
type XorDistance* = array[IdLength, byte]
|
||||
type XorDHasher* = proc(input: seq[byte]): array[IdLength, byte] {.
|
||||
raises: [], nimcall, noSideEffect, gcsafe
|
||||
.}
|
||||
|
||||
proc defaultHasher(
|
||||
input: seq[byte]
|
||||
): array[IdLength, byte] {.raises: [], nimcall, noSideEffect, gcsafe.} =
|
||||
return sha256.digest(input).data
|
||||
|
||||
# useful for testing purposes
|
||||
proc noOpHasher*(
|
||||
input: seq[byte]
|
||||
): array[IdLength, byte] {.raises: [], nimcall, noSideEffect, gcsafe.} =
|
||||
var data: array[IdLength, byte]
|
||||
discard data.copyFrom(input)
|
||||
return data
|
||||
|
||||
proc countLeadingZeroBits*(b: byte): int =
|
||||
for i in 0 .. 7:
|
||||
@@ -31,25 +49,23 @@ proc `<`*(a, b: XorDistance): bool =
|
||||
proc `<=`*(a, b: XorDistance): bool =
|
||||
cmp(a, b) <= 0
|
||||
|
||||
proc hashFor(k: Key): seq[byte] =
|
||||
proc hashFor(k: Key, hasher: Opt[XorDHasher]): seq[byte] =
|
||||
return
|
||||
@(
|
||||
case k.kind
|
||||
of KeyType.PeerId:
|
||||
sha256.digest(k.peerId.getBytes()).data
|
||||
hasher.get(defaultHasher)(k.peerId.getBytes())
|
||||
of KeyType.Raw:
|
||||
sha256.digest(k.data).data
|
||||
of KeyType.Unhashed:
|
||||
k.data
|
||||
hasher.get(defaultHasher)(k.data)
|
||||
)
|
||||
|
||||
proc xorDistance*(a, b: Key): XorDistance =
|
||||
let hashA = a.hashFor()
|
||||
let hashB = b.hashFor()
|
||||
proc xorDistance*(a, b: Key, hasher: Opt[XorDHasher]): XorDistance =
|
||||
let hashA = a.hashFor(hasher)
|
||||
let hashB = b.hashFor(hasher)
|
||||
var response: XorDistance
|
||||
for i in 0 ..< hashA.len:
|
||||
response[i] = hashA[i] xor hashB[i]
|
||||
return response
|
||||
|
||||
proc xorDistance*(a: PeerId, b: Key): XorDistance =
|
||||
xorDistance(a.toKey(), b)
|
||||
proc xorDistance*(a: PeerId, b: Key, hasher: Opt[XorDHasher]): XorDistance =
|
||||
xorDistance(a.toKey(), b, hasher)
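
The routing table's bucket index is the number of leading zero bits in this XOR distance; a small standalone illustration with 4-byte IDs (hypothetical names, not this module's API):

proc leadingZeroBits(dist: openArray[byte]): int =
  # count zero bits from the most significant byte down to the first set bit
  for b in dist:
    if b == 0:
      result += 8
    else:
      var x = b
      while (x and 0x80) == 0:
        result += 1
        x = x shl 1
      return

when isMainModule:
  let a = [byte 0x00, 0x00, 0x0F, 0xAA]
  let b = [byte 0x00, 0x00, 0x0C, 0xAA]
  var dist: array[4, byte]
  for i in 0 ..< 4:
    dist[i] = a[i] xor b[i]             # dist = 00 00 03 00
  doAssert leadingZeroBits(dist) == 22  # 16 bits plus 6 leading zeros in 0x03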
|
||||
|
||||
@@ -57,7 +57,8 @@ proc perf*(
|
||||
statsCopy.uploadBytes += toWrite.uint
|
||||
p.stats = statsCopy
|
||||
|
||||
await conn.close()
|
||||
# Close write side of the stream (half-close) to signal EOF to server
|
||||
await conn.closeWrite()
|
||||
|
||||
size = sizeToRead
|
||||
|
||||
@@ -71,6 +72,9 @@ proc perf*(
|
||||
statsCopy.duration = Moment.now() - start
|
||||
statsCopy.downloadBytes += toRead.uint
|
||||
p.stats = statsCopy
|
||||
|
||||
# Close the connection after reading
|
||||
await conn.close()
|
||||
except CancelledError as e:
|
||||
raise e
|
||||
except LPStreamError as e:
|
||||
|
||||
@@ -24,28 +24,29 @@ type Perf* = ref object of LPProtocol
|
||||
|
||||
proc new*(T: typedesc[Perf]): T {.public.} =
|
||||
var p = T()
|
||||
|
||||
proc handle(conn: Connection, proto: string) {.async: (raises: [CancelledError]).} =
|
||||
var bytesRead = 0
|
||||
try:
|
||||
trace "Received benchmark performance check", conn
|
||||
var
|
||||
sizeBuffer: array[8, byte]
|
||||
size: uint64
|
||||
await conn.readExactly(addr sizeBuffer[0], 8)
|
||||
size = uint64.fromBytesBE(sizeBuffer)
|
||||
|
||||
var toReadBuffer: array[PerfSize, byte]
|
||||
try:
|
||||
while true:
|
||||
bytesRead += await conn.readOnce(addr toReadBuffer[0], PerfSize)
|
||||
except CatchableError as exc:
|
||||
discard
|
||||
var uploadSizeBuffer: array[8, byte]
|
||||
await conn.readExactly(addr uploadSizeBuffer[0], 8)
|
||||
var uploadSize = uint64.fromBytesBE(uploadSizeBuffer)
|
||||
|
||||
var buf: array[PerfSize, byte]
|
||||
while size > 0:
|
||||
let toWrite = min(size, PerfSize)
|
||||
await conn.write(buf[0 ..< toWrite])
|
||||
size -= toWrite
|
||||
var readBuffer: array[PerfSize, byte]
|
||||
while not conn.atEof:
|
||||
try:
|
||||
let readBytes = await conn.readOnce(addr readBuffer[0], PerfSize)
|
||||
if readBytes == 0:
|
||||
break
|
||||
except LPStreamEOFError:
|
||||
break
|
||||
|
||||
var writeBuffer: array[PerfSize, byte]
|
||||
while uploadSize > 0:
|
||||
let toWrite = min(uploadSize, PerfSize)
|
||||
await conn.write(writeBuffer[0 ..< toWrite])
|
||||
uploadSize -= toWrite
|
||||
except CancelledError as exc:
|
||||
trace "cancelled perf handler"
|
||||
raise exc
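
For reference, a small hedged sketch of the 8-byte big-endian size prefix exchanged above, using stew/endians2 as the handler does:

import stew/endians2

when isMainModule:
  let uploadSize = 1_000_000'u64
  let header = uploadSize.toBytesBE()             # array[8, byte], network order
  doAssert header.len == 8
  doAssert uint64.fromBytesBE(header) == uploadSize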
|
||||
|
||||
@@ -29,10 +29,13 @@ import
|
||||
../../utility,
|
||||
../../switch
|
||||
|
||||
when defined(libp2p_gossipsub_1_4):
|
||||
import ./bandwidth
|
||||
|
||||
import results
|
||||
export results
|
||||
|
||||
import ./gossipsub/[types, scoring, behavior], ../../utils/heartbeat
|
||||
import ./gossipsub/[types, scoring, behavior, preamblestore], ../../utils/heartbeat
|
||||
|
||||
export types, scoring, behavior, pubsub
|
||||
|
||||
@@ -51,6 +54,10 @@ declareCounter(
|
||||
declareCounter(
|
||||
libp2p_gossipsub_idontwant_saved_messages, "number of duplicates avoided by idontwant"
|
||||
)
|
||||
declareCounter(
|
||||
libp2p_gossipsub_imreceiving_saved_messages,
|
||||
"number of duplicates avoided by imreceiving",
|
||||
)
|
||||
declareCounter(
|
||||
libp2p_gossipsub_saved_bytes,
|
||||
"bytes saved by gossipsub optimizations",
|
||||
@@ -222,6 +229,10 @@ method init*(g: GossipSub) =
|
||||
raise exc
|
||||
|
||||
g.handler = handler
|
||||
|
||||
when defined(libp2p_gossipsub_1_4):
|
||||
g.codecs &= GossipSubCodec_14
|
||||
|
||||
g.codecs &= GossipSubCodec_12
|
||||
g.codecs &= GossipSubCodec_11
|
||||
g.codecs &= GossipSubCodec_10
|
||||
@@ -240,6 +251,9 @@ method onNewPeer*(g: GossipSub, peer: PubSubPeer) =
|
||||
peer.iHaveBudget = IHavePeerBudget
|
||||
peer.pingBudget = PingsPeerBudget
|
||||
|
||||
when defined(libp2p_gossipsub_1_4):
|
||||
peer.preambleBudget = PreamblePeerBudget
|
||||
|
||||
method onPubSubPeerEvent*(
|
||||
p: GossipSub, peer: PubSubPeer, event: PubSubPeerEvent
|
||||
) {.gcsafe.} =
|
||||
@@ -346,11 +360,14 @@ proc handleControl(g: GossipSub, peer: PubSubPeer, control: ControlMessage) =
|
||||
|
||||
var respControl: ControlMessage
|
||||
g.handleIDontWant(peer, control.idontwant)
|
||||
when defined(libp2p_gossipsub_1_4):
|
||||
g.handlePreamble(peer, control.preamble)
|
||||
g.handleIMReceiving(peer, control.imreceiving)
|
||||
let iwant = g.handleIHave(peer, control.ihave)
|
||||
if iwant.messageIDs.len > 0:
|
||||
respControl.iwant.add(iwant)
|
||||
respControl.prune.add(g.handleGraft(peer, control.graft))
|
||||
let messages = g.handleIWant(peer, control.iwant)
|
||||
let (messages, msgIDs) = g.handleIWant(peer, control.iwant)
|
||||
|
||||
let
|
||||
isPruneNotEmpty = respControl.prune.len > 0
|
||||
@@ -371,13 +388,34 @@ proc handleControl(g: GossipSub, peer: PubSubPeer, control: ControlMessage) =
|
||||
g.send(peer, RPCMsg(control: some(respControl)), isHighPriority = true)
|
||||
|
||||
if messages.len > 0:
|
||||
for smsg in messages:
|
||||
when defined(libp2p_gossipsub_1_4):
|
||||
var preambles: seq[ControlPreamble]
|
||||
|
||||
for i, smsg in messages:
|
||||
let topic = smsg.topic
|
||||
if g.knownTopics.contains(topic):
|
||||
libp2p_pubsub_broadcast_messages.inc(labelValues = [topic])
|
||||
else:
|
||||
libp2p_pubsub_broadcast_messages.inc(labelValues = ["generic"])
|
||||
|
||||
when defined(libp2p_gossipsub_1_4):
|
||||
# should we send preamble here? (Not in specs so far)
|
||||
# so the receiver will send IMReceiving only for preambles received from mesh members
|
||||
preambles.add(
|
||||
ControlPreamble(
|
||||
topicID: smsg.topic,
|
||||
messageID: msgIDs[i],
|
||||
messageLength: smsg.data.len.uint32,
|
||||
)
|
||||
)
|
||||
|
||||
when defined(libp2p_gossipsub_1_4):
|
||||
g.broadcast(
|
||||
@[peer],
|
||||
RPCMsg(control: some(ControlMessage(preamble: preambles))),
|
||||
isHighPriority = true,
|
||||
)
|
||||
|
||||
# iwant replies have lower priority
|
||||
trace "sending iwant reply messages", peer
|
||||
g.send(peer, RPCMsg(messages: messages), isHighPriority = false)
|
||||
@@ -411,6 +449,34 @@ proc sendIDontWant(
|
||||
isHighPriority = true,
|
||||
)
|
||||
|
||||
when defined(libp2p_gossipsub_1_4):
|
||||
const preambleMessageSizeThreshold* = 40 * 1024 # 40KiB
|
||||
|
||||
proc sendPreamble(
|
||||
g: GossipSub, msg: Message, msgId: MessageId, toSendPeers: var HashSet[PubSubPeer]
|
||||
) =
|
||||
if msg.data.len < preambleMessageSizeThreshold:
|
||||
return
|
||||
|
||||
g.broadcast(
|
||||
toSendPeers.filterIt(it.codec == GossipSubCodec_14),
|
||||
RPCMsg(
|
||||
control: some(
|
||||
ControlMessage(
|
||||
preamble:
|
||||
@[
|
||||
ControlPreamble(
|
||||
topicID: msg.topic,
|
||||
messageID: msgId,
|
||||
messageLength: msg.data.len.uint32,
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
),
|
||||
isHighPriority = true,
|
||||
)
|
||||
|
||||
const iDontWantMessageSizeThreshold* = 512
|
||||
|
||||
proc isLargeMessage(msg: Message, msgId: MessageId): bool =
|
||||
@@ -489,6 +555,28 @@ proc validateAndRelay(
|
||||
|
||||
toSendPeers.exclIfIt(isMsgInIdontWant(it))
|
||||
|
||||
when defined(libp2p_gossipsub_1_4):
|
||||
proc isMsgInIMReceiving(it: PubSubPeer): bool =
|
||||
if it.heIsReceivings.hasKey(msgId):
|
||||
libp2p_gossipsub_imreceiving_saved_messages.inc
|
||||
return true
|
||||
return false
|
||||
|
||||
proc deferSend(deferPeers: HashSet[PubSubPeer]) {.async.} =
|
||||
let receiveTimeMs = calculateReceiveTimeMs(msg.data.len)
|
||||
await sleepAsync(receiveTimeMs.milliseconds)
|
||||
for deferPeer in deferPeers:
|
||||
if not deferPeer.isMsgInIdontWant:
|
||||
# No need to send a preamble at timeout
|
||||
g.broadcast(@[deferPeer], RPCMsg(messages: @[msg]), isHighPriority = false)
|
||||
|
||||
let allPeers = toSendPeers
|
||||
toSendPeers.exclIfIt(isMsgInIMReceiving(it))
|
||||
g.sendPreamble(msg, msgId, toSendPeers)
|
||||
if not PullOperation:
|
||||
let receivingPeers = allPeers - toSendPeers
|
||||
asyncSpawn deferSend(receivingPeers)
|
||||
|
||||
# In theory, if topics are the same in all messages, we could batch - we'd
|
||||
# also have to be careful to only include validated messages
|
||||
g.broadcast(toSendPeers, RPCMsg(messages: @[msg]), isHighPriority = false)
|
||||
@@ -602,6 +690,14 @@ method rpcHandler*(
|
||||
msgId = msgIdResult.get
|
||||
msgIdSalted = g.salt(msgId)
|
||||
|
||||
when defined(libp2p_gossipsub_1_4):
|
||||
if msg.data.len > preambleMessageSizeThreshold:
|
||||
g.ongoingReceives.del(msgId)
|
||||
g.ongoingIWantReceives.del(msgId)
|
||||
var startTime: Moment
|
||||
if peer.heIsSendings.pop(msgId, startTime):
|
||||
peer.bandwidthTracking.download.update(startTime, msg.data.len)
|
||||
|
||||
if g.addSeen(msgIdSalted):
|
||||
trace "Dropping already-seen message", msgId = shortLog(msgId), peer
|
||||
|
||||
@@ -629,9 +725,8 @@ method rpcHandler*(
|
||||
continue
|
||||
|
||||
if (msg.signature.len > 0 or g.verifySignature) and not msg.verify():
|
||||
# always validate if signature is present or required
|
||||
debug "Dropping message due to failed signature verification",
|
||||
msgId = shortLog(msgId), peer
|
||||
debug "Dropping message due to failed signature verification", msg = msg
|
||||
|
||||
await g.punishInvalidMessage(peer, msg)
|
||||
continue
|
||||
|
||||
@@ -794,7 +889,7 @@ method publish*(
|
||||
|
||||
let pubParams = publishParams.get(PublishParams())
|
||||
|
||||
let peers =
|
||||
var peers =
|
||||
if pubParams.useCustomConn:
|
||||
g.makePeersForPublishUsingCustomConn(topic)
|
||||
else:
|
||||
@@ -836,8 +931,12 @@ method publish*(
|
||||
if not pubParams.skipMCache:
|
||||
g.mcache.put(msgId, msg)
|
||||
|
||||
if g.parameters.sendIDontWantOnPublish and isLargeMessage(msg, msgId):
|
||||
g.sendIDontWant(msg, msgId, peers)
|
||||
if g.parameters.sendIDontWantOnPublish:
|
||||
if isLargeMessage(msg, msgId):
|
||||
g.sendIDontWant(msg, msgId, peers)
|
||||
|
||||
when defined(libp2p_gossipsub_1_4):
|
||||
g.sendPreamble(msg, msgId, peers)
|
||||
|
||||
g.broadcast(
|
||||
peers,
|
||||
@@ -898,6 +997,8 @@ method start*(
|
||||
g.heartbeatFut = g.heartbeat()
|
||||
g.scoringHeartbeatFut = g.scoringHeartbeat()
|
||||
g.directPeersLoop = g.maintainDirectPeers()
|
||||
when defined(libp2p_gossipsub_1_4):
|
||||
g.preambleExpirationFut = g.preambleExpirationHeartbeat()
|
||||
g.started = true
|
||||
fut
|
||||
|
||||
@@ -912,6 +1013,9 @@ method stop*(g: GossipSub): Future[void] {.async: (raises: [], raw: true).} =
|
||||
return fut
|
||||
|
||||
# stop heartbeat interval
|
||||
when defined(libp2p_gossipsub_1_4):
|
||||
g.preambleExpirationFut.cancelSoon()
|
||||
|
||||
g.directPeersLoop.cancelSoon()
|
||||
g.scoringHeartbeatFut.cancelSoon()
|
||||
g.heartbeatFut.cancelSoon()
|
||||
|
||||
@@ -24,6 +24,9 @@ import
|
||||
signed_envelope,
|
||||
utils/heartbeat,
|
||||
]
|
||||
when defined(libp2p_gossipsub_1_4):
|
||||
import ./preamblestore
|
||||
import ../bandwidth
|
||||
|
||||
logScope:
|
||||
topics = "libp2p gossipsub"
|
||||
@@ -60,6 +63,13 @@ declareCounter(
|
||||
labels = ["topic"],
|
||||
)
|
||||
declareGauge(libp2p_gossipsub_received_iwants, "received iwants", labels = ["kind"])
|
||||
declareCounter(
|
||||
libp2p_gossipsub_preamble_saved_iwants,
|
||||
"number of iwant requests avoided by preamble",
|
||||
labels = ["topic"],
|
||||
)
|
||||
|
||||
const MaxHeIsReceiving = 50
|
||||
|
||||
proc grafted*(g: GossipSub, p: PubSubPeer, topic: string) =
|
||||
g.withPeerStats(p.peerId) do(stats: var PeerStats):
|
||||
@@ -277,6 +287,11 @@ proc handlePrune*(g: GossipSub, peer: PubSubPeer, prunes: seq[ControlPrune]) =
|
||||
for handler in g.routingRecordsHandler:
|
||||
handler(peer.peerId, topic, routingRecords)
|
||||
|
||||
when defined(libp2p_gossipsub_1_4):
|
||||
proc addPossiblePeerToQuery(g: GossipSub, peer: PubSubPeer, messageId: MessageId) =
|
||||
g.ongoingReceives.addPossiblePeerToQuery(messageId, peer)
|
||||
g.ongoingIWantReceives.addPossiblePeerToQuery(messageId, peer)
|
||||
|
||||
proc handleIHave*(
|
||||
g: GossipSub, peer: PubSubPeer, ihaves: seq[ControlIHave]
|
||||
): ControlIWant =
|
||||
@@ -294,6 +309,14 @@ proc handleIHave*(
|
||||
if peer.iHaveBudget <= 0:
|
||||
break
|
||||
elif msgId notin res.messageIDs:
|
||||
when defined(libp2p_gossipsub_1_4):
|
||||
if g.ongoingReceives.hasKey(msgId) or
|
||||
g.ongoingIWantReceives.hasKey(msgId):
|
||||
g.addPossiblePeerToQuery(peer, msgId)
|
||||
libp2p_gossipsub_preamble_saved_iwants.inc(
|
||||
labelValues = [ihave.topicID]
|
||||
)
|
||||
continue
|
||||
res.messageIDs.add(msgId)
|
||||
dec peer.iHaveBudget
|
||||
trace "requested message via ihave", messageID = msgId
|
||||
@@ -308,13 +331,15 @@ proc handleIDontWant*(g: GossipSub, peer: PubSubPeer, iDontWants: seq[ControlIWa
|
||||
if peer.iDontWants[0].len >= IDontWantMaxCount:
|
||||
break
|
||||
peer.iDontWants[0].incl(g.salt(messageId))
|
||||
when defined(libp2p_gossipsub_1_4):
|
||||
peer.heIsReceivings.del(messageId)
|
||||
g.addPossiblePeerToQuery(peer, messageId)
|
||||
|
||||
proc handleIWant*(
|
||||
g: GossipSub, peer: PubSubPeer, iwants: seq[ControlIWant]
|
||||
): seq[Message] =
|
||||
var
|
||||
messages: seq[Message]
|
||||
invalidRequests = 0
|
||||
): tuple[messages: seq[Message], ids: seq[MessageId]] =
|
||||
var response: tuple[messages: seq[Message], ids: seq[MessageId]]
|
||||
var invalidRequests = 0
|
||||
if peer.score < g.parameters.gossipThreshold:
|
||||
trace "iwant: ignoring low score peer", peer, score = peer.score
|
||||
else:
|
||||
@@ -328,14 +353,101 @@ proc handleIWant*(
|
||||
invalidRequests.inc()
|
||||
if invalidRequests > 20:
|
||||
libp2p_gossipsub_received_iwants.inc(1, labelValues = ["skipped"])
|
||||
return messages
|
||||
return response
|
||||
continue
|
||||
let msg = g.mcache.get(mid).valueOr:
|
||||
libp2p_gossipsub_received_iwants.inc(1, labelValues = ["unknown"])
|
||||
continue
|
||||
libp2p_gossipsub_received_iwants.inc(1, labelValues = ["correct"])
|
||||
messages.add(msg)
|
||||
return messages
|
||||
response.messages.add(msg)
|
||||
response.ids.add(mid)
|
||||
return response
|
||||
|
||||
when defined(libp2p_gossipsub_1_4):
|
||||
proc medianDownloadRate*(p: var HashSet[PubSubPeer]): float =
|
||||
if p.len == 0:
|
||||
return 0
|
||||
|
||||
let vals = p.toSeq().mapIt(it.bandwidthTracking.download.value()).sorted()
|
||||
|
||||
let mid = vals.len div 2
|
||||
if vals.len mod 2 == 0:
|
||||
(vals[mid - 1] + vals[mid]) / 2
|
||||
else:
|
||||
vals[mid]
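
A standalone check of the median computation above (same even/odd midpoint rule, plain floats):

import std/algorithm

proc median(vals: seq[float]): float =
  if vals.len == 0:
    return 0
  let s = sorted(vals)
  let mid = s.len div 2
  if s.len mod 2 == 0:
    result = (s[mid - 1] + s[mid]) / 2
  else:
    result = s[mid]

when isMainModule:
  doAssert median(@[3.0, 1.0, 2.0]) == 2.0
  doAssert median(@[4.0, 1.0, 2.0, 3.0]) == 2.5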
|
||||
|
||||
proc handlePreamble*(
|
||||
g: GossipSub, peer: PubSubPeer, preambles: seq[ControlPreamble]
|
||||
) =
|
||||
let starts = Moment.now()
|
||||
|
||||
for preamble in preambles:
|
||||
dec peer.preambleBudget
|
||||
if peer.preambleBudget <= 0:
|
||||
return
|
||||
if g.hasSeen(g.salt(preamble.messageID)):
|
||||
continue
|
||||
elif peer.heIsSendings.hasKey(preamble.messageID):
|
||||
continue
|
||||
elif g.ongoingReceives.hasKey(preamble.messageID):
|
||||
#TODO: add to conflicts_watch if length is different
|
||||
continue
|
||||
else:
|
||||
peer.heIsSendings[preamble.messageID] = starts
|
||||
var toSendPeers = HashSet[PubSubPeer]()
|
||||
g.mesh.withValue(preamble.topicID, peers):
|
||||
toSendPeers.incl(peers[])
|
||||
toSendPeers.incl(g.subscribedDirectPeers.getOrDefault(preamble.topicID))
|
||||
var peers = toSendPeers.filterIt(it.codec == GossipSubCodec_14)
|
||||
let bytesPerSecond = peer.bandwidthTracking.download.value()
|
||||
let transmissionTimeMs =
|
||||
calculateReceiveTimeMs(preamble.messageLength.int64, bytesPerSecond.int64)
|
||||
let expires = starts + transmissionTimeMs.milliseconds
|
||||
|
||||
# We send IMReceiving only if received from mesh members
|
||||
if peer notin peers:
|
||||
if not g.ongoingIWantReceives.hasKey(preamble.messageID):
|
||||
g.ongoingIWantReceives[preamble.messageID] =
|
||||
PreambleInfo.init(preamble, peer, starts, expires)
|
||||
|
||||
trace "preamble: ignoring out of mesh peer", peer
|
||||
continue
|
||||
|
||||
g.ongoingReceives[preamble.messageID] =
|
||||
PreambleInfo.init(preamble, peer, starts, expires)
|
||||
|
||||
#Send imreceiving only if received from faster mesh members
|
||||
if bytesPerSecond >= toSendPeers.medianDownloadRate():
|
||||
g.broadcast(
|
||||
peers,
|
||||
RPCMsg(
|
||||
control: some(
|
||||
ControlMessage(
|
||||
imreceiving:
|
||||
@[
|
||||
ControlIMReceiving(
|
||||
messageID: preamble.messageID,
|
||||
messageLength: preamble.messageLength,
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
),
|
||||
isHighPriority = true,
|
||||
)
|
||||
|
||||
proc handleIMReceiving*(
|
||||
g: GossipSub, peer: PubSubPeer, imreceivings: seq[ControlIMReceiving]
|
||||
) =
|
||||
for imreceiving in imreceivings:
|
||||
if peer.heIsReceivings.len > MaxHeIsReceiving:
|
||||
break
|
||||
#Ignore if message length is different
|
||||
g.ongoingReceives.withValue(imreceiving.messageID, pInfo):
|
||||
if pInfo.messageLength != imreceiving.messageLength:
|
||||
continue
|
||||
peer.heIsReceivings[imreceiving.messageID] = imreceiving.messageLength
|
||||
#No need to check mcache. In that case, we might have already transmitted/transmitting
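
The preamble and IMReceiving logic above budgets a transfer window from the message length and the sender's observed download rate; a hedged standalone sketch of that estimate (calculateReceiveTimeMs itself lives in ./bandwidth and may differ):

proc estimateReceiveTimeMs(messageLen, bytesPerSecond: int64): int64 =
  # rough milliseconds needed to receive `messageLen` bytes at the given rate
  if bytesPerSecond <= 0:
    return 0
  (messageLen * 1000) div bytesPerSecond

when isMainModule:
  doAssert estimateReceiveTimeMs(512 * 1024, 1_048_576) == 500 # about 0.5 s at 1 MiB/s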
|
||||
|
||||
proc commitMetrics(metrics: var MeshMetrics) =
|
||||
libp2p_gossipsub_low_peers_topics.set(metrics.lowPeersTopics)
|
||||
@@ -710,6 +822,9 @@ proc onHeartbeat(g: GossipSub) =
|
||||
peer.iHaveBudget = IHavePeerBudget
|
||||
peer.pingBudget = PingsPeerBudget
|
||||
|
||||
when defined(libp2p_gossipsub_1_4):
|
||||
peer.preambleBudget = PreamblePeerBudget
|
||||
|
||||
var meshMetrics = MeshMetrics()
|
||||
|
||||
for t in toSeq(g.topics.keys):
|
||||
@@ -777,3 +892,72 @@ proc heartbeat*(g: GossipSub) {.async: (raises: [CancelledError]).} =
|
||||
for trigger in g.heartbeatEvents:
|
||||
trace "firing heartbeat event", instance = cast[int](g)
|
||||
trigger.fire()
|
||||
|
||||
when defined(libp2p_gossipsub_1_4):
|
||||
proc preambleExpirationHeartbeat*(
|
||||
g: GossipSub
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
heartbeat "GossipSub: Preamble Expiration", 200.milliseconds:
|
||||
trace "running preamble expiration heartbeat", instance = cast[int](g)
|
||||
|
||||
while true:
|
||||
var expiredOngoingReceive = g.ongoingReceives.popExpired(Moment.now()).valueOr:
|
||||
break
|
||||
|
||||
if not expiredOngoingReceive.sender.isNil:
|
||||
let sender = expiredOngoingReceive.sender
|
||||
if g.peers.hasKey(sender.peerId):
|
||||
sender.behaviourPenalty += 0.1
|
||||
|
||||
if PullOperation:
|
||||
var possiblePeers = expiredOngoingReceive.possiblePeersToQuery()
|
||||
g.rng.shuffle(possiblePeers)
|
||||
|
||||
var peer: PubSubPeer = nil
|
||||
for peerId in possiblePeers:
|
||||
try:
|
||||
if g.peers.hasKey(peerId) and g.peers[peerId].codec == GossipSubCodec_14:
|
||||
peer = g.peers[peerId]
|
||||
break
|
||||
except KeyError:
|
||||
assert false, "checked with hasKey"
|
||||
|
||||
if peer.isNil:
|
||||
trace "no peer available to send IWANT for an expiredOngoingReceive",
|
||||
messageID = expiredOngoingReceive.messageId
|
||||
continue
|
||||
|
||||
let starts = Moment.now()
|
||||
|
||||
g.broadcast(
|
||||
@[peer],
|
||||
RPCMsg(
|
||||
control: some(
|
||||
ControlMessage(
|
||||
iwant: @[ControlIWant(messageIDs: @[expiredOngoingReceive.messageId])]
|
||||
)
|
||||
)
|
||||
),
|
||||
isHighPriority = true,
|
||||
)
|
||||
|
||||
let bytesPerSecond = peer.bandwidthTracking.download.value()
|
||||
let transmissionTimeMs = calculateReceiveTimeMs(
|
||||
expiredOngoingReceive.messageLength.int64, bytesPerSecond.int64
|
||||
)
|
||||
let expires = starts + transmissionTimeMs.milliseconds
|
||||
|
||||
# Setting new data before reinserting the preamble
|
||||
expiredOngoingReceive.startsAt = starts
|
||||
expiredOngoingReceive.expiresAt = expires
|
||||
expiredOngoingReceive.sender = peer
|
||||
g.ongoingIWantReceives[expiredOngoingReceive.messageId] =
|
||||
expiredOngoingReceive
|
||||
|
||||
while true:
|
||||
let expiredOngoingIWantReceived = g.ongoingIWantReceives.popExpired(
|
||||
Moment.now()
|
||||
).valueOr:
|
||||
break
|
||||
# TODO: use expiredOngoingIWantReceived
|
||||
# TODO: what should we do here?
|
||||
|
||||
@@ -8,16 +8,16 @@ import ../pubsubpeer
|
||||
proc `<`(a, b: PreambleInfo): bool =
|
||||
a.expiresAt < b.expiresAt
|
||||
|
||||
proc init*(_: typedesc[PeerSet]): PeerSet =
|
||||
proc init*(T: typedesc[PeerSet]): T =
|
||||
PeerSet(order: @[], peers: initHashSet[PeerId]())
|
||||
|
||||
proc init*(
|
||||
_: typedesc[PreambleInfo],
|
||||
T: typedesc[PreambleInfo],
|
||||
preamble: ControlPreamble,
|
||||
sender: PubSubPeer,
|
||||
startsAt: Moment,
|
||||
expiresAt: Moment,
|
||||
): PreambleInfo =
|
||||
): T =
|
||||
PreambleInfo(
|
||||
messageId: preamble.messageID,
|
||||
messageLength: preamble.messageLength,
|
||||
|
||||
@@ -9,13 +9,25 @@
|
||||
|
||||
{.push raises: [].}
|
||||
|
||||
import std/[tables, sets, sequtils]
|
||||
import std/[tables, sets, sequtils, strutils]
|
||||
import ./pubsubpeer, ../../peerid
|
||||
|
||||
export tables, sets
|
||||
|
||||
type PeerTable* = Table[string, HashSet[PubSubPeer]] # topic string to peer map
|
||||
|
||||
proc `$`*(table: PeerTable): string =
|
||||
result.add("PeerTable ")
|
||||
result.add("topics (" & $table.len & ")")
|
||||
|
||||
for topic, peers in table:
|
||||
result.add(" topic: ")
|
||||
result.add($topic)
|
||||
result.add(" peers: ")
|
||||
result.add("(" & $peers.len & ") [")
|
||||
result.add(peers.mapIt($it).join(", "))
|
||||
result.add("]")
|
||||
|
||||
proc hasPeerId*(t: PeerTable, topic: string, peerId: PeerId): bool =
|
||||
if topic in t:
|
||||
try:
|
||||
|
||||
@@ -21,7 +21,11 @@ import
|
||||
../../crypto/crypto,
|
||||
../../protobuf/minprotobuf,
|
||||
../../utility,
|
||||
../../utils/sequninit
|
||||
../../utils/sequninit,
|
||||
./bandwidth
|
||||
|
||||
when defined(libp2p_gossipsub_1_4):
|
||||
import ./bandwidth
|
||||
|
||||
export peerid, connection, deques
|
||||
|
||||
@@ -168,6 +172,9 @@ proc getAgent*(peer: PubSubPeer): string =
|
||||
else:
|
||||
"unknown"
|
||||
|
||||
proc `$`*(p: PubSubPeer): string =
|
||||
$p.peerId
|
||||
|
||||
func hash*(p: PubSubPeer): Hash =
|
||||
p.peerId.hash
|
||||
|
||||
@@ -351,19 +358,23 @@ proc clearSendPriorityQueue(p: PubSubPeer) =
|
||||
value = p.rpcmessagequeue.sendPriorityQueue.len.int64, labelValues = [$p.peerId]
|
||||
)
|
||||
|
||||
proc sendMsgContinue(conn: Connection, msgFut: Future[void]) {.async: (raises: []).} =
|
||||
proc sendMsgContinue(
|
||||
conn: Connection, msgFut: Future[void]
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
# Continuation for a pending `sendMsg` future from below
|
||||
#
|
||||
# conn.close() in exceptions will clean up the send connection. Next time conn is used,
|
||||
# it will have its close flag set and thus will be recycled.
|
||||
|
||||
try:
|
||||
await msgFut
|
||||
trace "sent pubsub message to remote", conn
|
||||
except CatchableError as exc: # never cancelled
|
||||
# Because we detach the send call from the currently executing task using
|
||||
# asyncSpawn, no exceptions may leak out of it
|
||||
trace "Unable to send to remote", conn, description = exc.msg
|
||||
# Next time sendConn is used, it will have its close flag set and thus
|
||||
# will be recycled
|
||||
|
||||
await conn.close() # This will clean up the send connection
|
||||
except CancelledError as exc:
|
||||
await conn.close()
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
trace "Unexpected exception in sendMsgContinue", conn, description = exc.msg
|
||||
await conn.close()
|
||||
|
||||
proc sendMsgSlow(p: PubSubPeer, msg: seq[byte]) {.async: (raises: [CancelledError]).} =
|
||||
# Slow path of `sendMsg` where msg is held in memory while send connection is
|
||||
@@ -383,7 +394,7 @@ proc sendMsgSlow(p: PubSubPeer, msg: seq[byte]) {.async: (raises: [CancelledErro
|
||||
|
||||
proc sendMsg(
|
||||
p: PubSubPeer, msg: seq[byte], useCustomConn: bool = false
|
||||
): Future[void] {.async: (raises: []).} =
|
||||
): Future[void] {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
type ConnectionType = enum
|
||||
ctCustom
|
||||
ctSend
|
||||
@@ -403,17 +414,15 @@ proc sendMsg(
|
||||
slowPath = true
|
||||
(nil, ctSlow)
|
||||
|
||||
if not slowPath:
|
||||
trace "sending encoded msg to peer",
|
||||
conntype = $connType, conn = conn, encoded = shortLog(msg)
|
||||
let f = conn.writeLp(msg)
|
||||
if not f.completed():
|
||||
sendMsgContinue(conn, f)
|
||||
else:
|
||||
f
|
||||
else:
|
||||
if slowPath:
|
||||
trace "sending encoded msg to peer via slow path"
|
||||
sendMsgSlow(p, msg)
|
||||
await sendMsgSlow(p, msg)
|
||||
return
|
||||
|
||||
trace "sending encoded msg to peer",
|
||||
conntype = $connType, conn = conn, encoded = shortLog(msg)
|
||||
|
||||
await sendMsgContinue(conn, conn.writeLp(msg))
|
||||
|
||||
proc sendEncoded*(
|
||||
p: PubSubPeer, msg: seq[byte], isHighPriority: bool, useCustomConn: bool = false
|
||||
@@ -474,19 +483,16 @@ iterator splitRPCMsg(
|
||||
## exceeds the `maxSize` when trying to fit into an empty `RPCMsg`, the latter is skipped as too large to send.
|
||||
## Every constructed `RPCMsg` is then encoded, optionally anonymized, and yielded as a sequence of bytes.
|
||||
|
||||
var currentRPCMsg = rpcMsg
|
||||
currentRPCMsg.messages = newSeq[Message]()
|
||||
|
||||
var currentSize = byteSize(currentRPCMsg)
|
||||
var currentRPCMsg = RPCMsg()
|
||||
var currentSize = 0
|
||||
|
||||
for msg in rpcMsg.messages:
|
||||
let msgSize = byteSize(msg)
|
||||
|
||||
# Check if adding the next message will exceed maxSize
|
||||
if float(currentSize + msgSize) * 1.1 > float(maxSize):
|
||||
# Guessing 10% protobuf overhead
|
||||
if currentRPCMsg.messages.len == 0:
|
||||
trace "message too big to sent", peer, rpcMsg = shortLog(currentRPCMsg)
|
||||
if currentSize + msgSize > maxSize:
|
||||
if msgSize > maxSize:
|
||||
warn "message too big to sent", peer, rpcMsg = shortLog(msg)
|
||||
continue # Skip this message
|
||||
|
||||
trace "sending msg to peer", peer, rpcMsg = shortLog(currentRPCMsg)
|
||||
@@ -498,11 +504,9 @@ iterator splitRPCMsg(
|
||||
currentSize += msgSize
|
||||
|
||||
# Check if there is a non-empty currentRPCMsg left to be added
|
||||
if currentSize > 0 and currentRPCMsg.messages.len > 0:
|
||||
if currentRPCMsg.messages.len > 0:
|
||||
trace "sending msg to peer", peer, rpcMsg = shortLog(currentRPCMsg)
|
||||
yield encodeRpcMsg(currentRPCMsg, anonymize)
|
||||
else:
|
||||
trace "message too big to sent", peer, rpcMsg = shortLog(currentRPCMsg)
|
||||
|
||||
proc send*(
|
||||
p: PubSubPeer,
|
||||
@@ -538,8 +542,11 @@ proc send*(
|
||||
sendMetrics(msg)
|
||||
encodeRpcMsg(msg, anonymize)
|
||||
|
||||
if encoded.len > p.maxMessageSize and msg.messages.len > 1:
|
||||
for encodedSplitMsg in splitRPCMsg(p, msg, p.maxMessageSize, anonymize):
|
||||
# Messages should not exceed 90% of maxMessageSize. Guessing 10% protobuf overhead.
|
||||
let maxEncodedMsgSize = (p.maxMessageSize * 90) div 100
|
||||
|
||||
if encoded.len > maxEncodedMsgSize and msg.messages.len > 1:
|
||||
for encodedSplitMsg in splitRPCMsg(p, msg, maxEncodedMsgSize, anonymize):
|
||||
asyncSpawn p.sendEncoded(encodedSplitMsg, isHighPriority, useCustomConn)
|
||||
else:
|
||||
# If the message size is within limits, send it as is
|
||||
@@ -553,7 +560,9 @@ proc canAskIWant*(p: PubSubPeer, msgId: MessageId): bool =
|
||||
return true
|
||||
return false
|
||||
|
||||
proc sendNonPriorityTask(p: PubSubPeer) {.async: (raises: [CancelledError]).} =
|
||||
proc sendNonPriorityTask(
|
||||
p: PubSubPeer
|
||||
) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
while true:
|
||||
# we send non-priority messages only if there are no pending priority messages
|
||||
let msg = await p.rpcmessagequeue.nonPriorityQueue.popFirst()
|
||||
|
||||
@@ -41,15 +41,36 @@ func defaultMsgIdProvider*(m: Message): Result[MessageId, ValidationResult] =
|
||||
proc sign*(msg: Message, privateKey: PrivateKey): CryptoResult[seq[byte]] =
|
||||
ok((?privateKey.sign(PubSubPrefix & encodeMessage(msg, false))).getBytes())
|
||||
|
||||
proc extractPublicKey(m: Message): Opt[PublicKey] =
|
||||
var pubkey: PublicKey
|
||||
if m.fromPeer.hasPublicKey() and m.fromPeer.extractPublicKey(pubkey):
|
||||
Opt.some(pubkey)
|
||||
elif m.key.len > 0 and pubkey.init(m.key):
|
||||
# check if peerId extracted from m.key is the same as m.fromPeer
|
||||
let derivedPeerId = PeerId.init(pubkey).valueOr:
|
||||
warn "could not derive peerId from key field"
|
||||
return Opt.none(PublicKey)
|
||||
|
||||
if derivedPeerId != m.fromPeer:
|
||||
warn "peerId derived from msg.key is not the same as msg.fromPeer",
|
||||
derivedPeerId = derivedPeerId, fromPeer = m.fromPeer
|
||||
return Opt.none(PublicKey)
|
||||
Opt.some(pubkey)
|
||||
else:
|
||||
Opt.none(PublicKey)
|
||||
|
||||
proc verify*(m: Message): bool =
|
||||
if m.signature.len > 0 and m.key.len > 0:
|
||||
if m.signature.len > 0:
|
||||
var msg = m
|
||||
msg.signature = @[]
|
||||
msg.key = @[]
|
||||
|
||||
var remote: Signature
|
||||
var key: PublicKey
|
||||
if remote.init(m.signature) and key.init(m.key):
|
||||
let key = m.extractPublicKey().valueOr:
|
||||
warn "could not extract public key", msg = m
|
||||
return false
|
||||
|
||||
if remote.init(m.signature):
|
||||
trace "verifying signature", remoteSignature = remote
|
||||
result = remote.verify(PubSubPrefix & encodeMessage(msg, false), key)
|
||||
|
||||
|
||||
@@ -1,843 +1,3 @@
|
||||
# Nim-LibP2P
|
||||
# Copyright (c) 2023-2024 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
import ./rendezvous/rendezvous
|
||||
|
||||
{.push raises: [].}
|
||||
|
||||
import tables, sequtils, sugar, sets
|
||||
import metrics except collect
|
||||
import chronos, chronicles, bearssl/rand, stew/[byteutils, objects]
|
||||
import
|
||||
./protocol,
|
||||
../protobuf/minprotobuf,
|
||||
../switch,
|
||||
../routing_record,
|
||||
../utils/heartbeat,
|
||||
../stream/connection,
|
||||
../utils/offsettedseq,
|
||||
../utils/semaphore,
|
||||
../discovery/discoverymngr
|
||||
|
||||
export chronicles
|
||||
|
||||
logScope:
|
||||
topics = "libp2p discovery rendezvous"
|
||||
|
||||
declareCounter(libp2p_rendezvous_register, "number of advertise requests")
|
||||
declareCounter(libp2p_rendezvous_discover, "number of discovery requests")
|
||||
declareGauge(libp2p_rendezvous_registered, "number of registered peers")
|
||||
declareGauge(libp2p_rendezvous_namespaces, "number of registered namespaces")
|
||||
|
||||
const
|
||||
RendezVousCodec* = "/rendezvous/1.0.0"
|
||||
MinimumDuration* = 2.hours
|
||||
MaximumDuration = 72.hours
|
||||
MaximumMessageLen = 1 shl 22 # 4MB
|
||||
MinimumNamespaceLen = 1
|
||||
MaximumNamespaceLen = 255
|
||||
RegistrationLimitPerPeer = 1000
|
||||
DiscoverLimit = 1000'u64
|
||||
SemaphoreDefaultSize = 5
|
||||
|
||||
type
|
||||
MessageType {.pure.} = enum
|
||||
Register = 0
|
||||
RegisterResponse = 1
|
||||
Unregister = 2
|
||||
Discover = 3
|
||||
DiscoverResponse = 4
|
||||
|
||||
ResponseStatus = enum
|
||||
Ok = 0
|
||||
InvalidNamespace = 100
|
||||
InvalidSignedPeerRecord = 101
|
||||
InvalidTTL = 102
|
||||
InvalidCookie = 103
|
||||
NotAuthorized = 200
|
||||
InternalError = 300
|
||||
Unavailable = 400
|
||||
|
||||
Cookie = object
|
||||
offset: uint64
|
||||
ns: Opt[string]
|
||||
|
||||
Register = object
|
||||
ns: string
|
||||
signedPeerRecord: seq[byte]
|
||||
ttl: Opt[uint64] # in seconds
|
||||
|
||||
RegisterResponse = object
|
||||
status: ResponseStatus
|
||||
text: Opt[string]
|
||||
ttl: Opt[uint64] # in seconds
|
||||
|
||||
Unregister = object
|
||||
ns: string
|
||||
|
||||
Discover = object
|
||||
ns: Opt[string]
|
||||
limit: Opt[uint64]
|
||||
cookie: Opt[seq[byte]]
|
||||
|
||||
DiscoverResponse = object
|
||||
registrations: seq[Register]
|
||||
cookie: Opt[seq[byte]]
|
||||
status: ResponseStatus
|
||||
text: Opt[string]
|
||||
|
||||
Message = object
|
||||
msgType: MessageType
|
||||
register: Opt[Register]
|
||||
registerResponse: Opt[RegisterResponse]
|
||||
unregister: Opt[Unregister]
|
||||
discover: Opt[Discover]
|
||||
discoverResponse: Opt[DiscoverResponse]
|
||||
|
||||
proc encode(c: Cookie): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
result.write(1, c.offset)
|
||||
if c.ns.isSome():
|
||||
result.write(2, c.ns.get())
|
||||
result.finish()
|
||||
|
||||
proc encode(r: Register): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
result.write(1, r.ns)
|
||||
result.write(2, r.signedPeerRecord)
|
||||
r.ttl.withValue(ttl):
|
||||
result.write(3, ttl)
|
||||
result.finish()
|
||||
|
||||
proc encode(rr: RegisterResponse): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
result.write(1, rr.status.uint)
|
||||
rr.text.withValue(text):
|
||||
result.write(2, text)
|
||||
rr.ttl.withValue(ttl):
|
||||
result.write(3, ttl)
|
||||
result.finish()
|
||||
|
||||
proc encode(u: Unregister): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
result.write(1, u.ns)
|
||||
result.finish()
|
||||
|
||||
proc encode(d: Discover): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
if d.ns.isSome():
|
||||
result.write(1, d.ns.get())
|
||||
d.limit.withValue(limit):
|
||||
result.write(2, limit)
|
||||
d.cookie.withValue(cookie):
|
||||
result.write(3, cookie)
|
||||
result.finish()
|
||||
|
||||
proc encode(dr: DiscoverResponse): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
for reg in dr.registrations:
|
||||
result.write(1, reg.encode())
|
||||
dr.cookie.withValue(cookie):
|
||||
result.write(2, cookie)
|
||||
result.write(3, dr.status.uint)
|
||||
dr.text.withValue(text):
|
||||
result.write(4, text)
|
||||
result.finish()
|
||||
|
||||
proc encode(msg: Message): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
result.write(1, msg.msgType.uint)
|
||||
msg.register.withValue(register):
|
||||
result.write(2, register.encode())
|
||||
msg.registerResponse.withValue(registerResponse):
|
||||
result.write(3, registerResponse.encode())
|
||||
msg.unregister.withValue(unregister):
|
||||
result.write(4, unregister.encode())
|
||||
msg.discover.withValue(discover):
|
||||
result.write(5, discover.encode())
|
||||
msg.discoverResponse.withValue(discoverResponse):
|
||||
result.write(6, discoverResponse.encode())
|
||||
result.finish()
|
||||
|
||||
proc decode(_: typedesc[Cookie], buf: seq[byte]): Opt[Cookie] =
|
||||
var
|
||||
c: Cookie
|
||||
ns: string
|
||||
let
|
||||
pb = initProtoBuffer(buf)
|
||||
r1 = pb.getRequiredField(1, c.offset)
|
||||
r2 = pb.getField(2, ns)
|
||||
if r1.isErr() or r2.isErr():
|
||||
return Opt.none(Cookie)
|
||||
if r2.get(false):
|
||||
c.ns = Opt.some(ns)
|
||||
Opt.some(c)
|
||||
|
||||
proc decode(_: typedesc[Register], buf: seq[byte]): Opt[Register] =
|
||||
var
|
||||
r: Register
|
||||
ttl: uint64
|
||||
let
|
||||
pb = initProtoBuffer(buf)
|
||||
r1 = pb.getRequiredField(1, r.ns)
|
||||
r2 = pb.getRequiredField(2, r.signedPeerRecord)
|
||||
r3 = pb.getField(3, ttl)
|
||||
if r1.isErr() or r2.isErr() or r3.isErr():
|
||||
return Opt.none(Register)
|
||||
if r3.get(false):
|
||||
r.ttl = Opt.some(ttl)
|
||||
Opt.some(r)
|
||||
|
||||
proc decode(_: typedesc[RegisterResponse], buf: seq[byte]): Opt[RegisterResponse] =
|
||||
var
|
||||
rr: RegisterResponse
|
||||
statusOrd: uint
|
||||
text: string
|
||||
ttl: uint64
|
||||
let
|
||||
pb = initProtoBuffer(buf)
|
||||
r1 = pb.getRequiredField(1, statusOrd)
|
||||
r2 = pb.getField(2, text)
|
||||
r3 = pb.getField(3, ttl)
|
||||
if r1.isErr() or r2.isErr() or r3.isErr() or
|
||||
not checkedEnumAssign(rr.status, statusOrd):
|
||||
return Opt.none(RegisterResponse)
|
||||
if r2.get(false):
|
||||
rr.text = Opt.some(text)
|
||||
if r3.get(false):
|
||||
rr.ttl = Opt.some(ttl)
|
||||
Opt.some(rr)
|
||||
|
||||
proc decode(_: typedesc[Unregister], buf: seq[byte]): Opt[Unregister] =
|
||||
var u: Unregister
|
||||
let
|
||||
pb = initProtoBuffer(buf)
|
||||
r1 = pb.getRequiredField(1, u.ns)
|
||||
if r1.isErr():
|
||||
return Opt.none(Unregister)
|
||||
Opt.some(u)
|
||||
|
||||
proc decode(_: typedesc[Discover], buf: seq[byte]): Opt[Discover] =
|
||||
var
|
||||
d: Discover
|
||||
limit: uint64
|
||||
cookie: seq[byte]
|
||||
ns: string
|
||||
let
|
||||
pb = initProtoBuffer(buf)
|
||||
r1 = pb.getField(1, ns)
|
||||
r2 = pb.getField(2, limit)
|
||||
r3 = pb.getField(3, cookie)
|
||||
if r1.isErr() or r2.isErr() or r3.isErr:
|
||||
return Opt.none(Discover)
|
||||
if r1.get(false):
|
||||
d.ns = Opt.some(ns)
|
||||
if r2.get(false):
|
||||
d.limit = Opt.some(limit)
|
||||
if r3.get(false):
|
||||
d.cookie = Opt.some(cookie)
|
||||
Opt.some(d)
|
||||
|
||||
proc decode(_: typedesc[DiscoverResponse], buf: seq[byte]): Opt[DiscoverResponse] =
|
||||
var
|
||||
dr: DiscoverResponse
|
||||
registrations: seq[seq[byte]]
|
||||
cookie: seq[byte]
|
||||
statusOrd: uint
|
||||
text: string
|
||||
let
|
||||
pb = initProtoBuffer(buf)
|
||||
r1 = pb.getRepeatedField(1, registrations)
|
||||
r2 = pb.getField(2, cookie)
|
||||
r3 = pb.getRequiredField(3, statusOrd)
|
||||
r4 = pb.getField(4, text)
|
||||
if r1.isErr() or r2.isErr() or r3.isErr or r4.isErr() or
|
||||
not checkedEnumAssign(dr.status, statusOrd):
|
||||
return Opt.none(DiscoverResponse)
|
||||
for reg in registrations:
|
||||
var r: Register
|
||||
let regOpt = Register.decode(reg).valueOr:
|
||||
return
|
||||
dr.registrations.add(regOpt)
|
||||
if r2.get(false):
|
||||
dr.cookie = Opt.some(cookie)
|
||||
if r4.get(false):
|
||||
dr.text = Opt.some(text)
|
||||
Opt.some(dr)
|
||||
|
||||
proc decode(_: typedesc[Message], buf: seq[byte]): Opt[Message] =
|
||||
var
|
||||
msg: Message
|
||||
statusOrd: uint
|
||||
pbr, pbrr, pbu, pbd, pbdr: ProtoBuffer
|
||||
let pb = initProtoBuffer(buf)
|
||||
|
||||
?pb.getRequiredField(1, statusOrd).toOpt
|
||||
if not checkedEnumAssign(msg.msgType, statusOrd):
|
||||
return Opt.none(Message)
|
||||
|
||||
if ?pb.getField(2, pbr).optValue:
|
||||
msg.register = Register.decode(pbr.buffer)
|
||||
if msg.register.isNone():
|
||||
return Opt.none(Message)
|
||||
|
||||
if ?pb.getField(3, pbrr).optValue:
|
||||
msg.registerResponse = RegisterResponse.decode(pbrr.buffer)
|
||||
if msg.registerResponse.isNone():
|
||||
return Opt.none(Message)
|
||||
|
||||
if ?pb.getField(4, pbu).optValue:
|
||||
msg.unregister = Unregister.decode(pbu.buffer)
|
||||
if msg.unregister.isNone():
|
||||
return Opt.none(Message)
|
||||
|
||||
if ?pb.getField(5, pbd).optValue:
|
||||
msg.discover = Discover.decode(pbd.buffer)
|
||||
if msg.discover.isNone():
|
||||
return Opt.none(Message)
|
||||
|
||||
if ?pb.getField(6, pbdr).optValue:
|
||||
msg.discoverResponse = DiscoverResponse.decode(pbdr.buffer)
|
||||
if msg.discoverResponse.isNone():
|
||||
return Opt.none(Message)
|
||||
|
||||
Opt.some(msg)
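
A hedged sketch of the minprotobuf round trip these encode/decode procs rely on: write a couple of fields, finish the buffer, then read them back (field numbers and values are arbitrary, and the import path assumes this module's location):

import ../protobuf/minprotobuf

when isMainModule:
  var enc = initProtoBuffer()
  enc.write(1, 3'u64)              # e.g. a message type ordinal
  enc.write(2, "my-namespace")
  enc.finish()

  var msgType: uint64
  var ns: string
  let dec = initProtoBuffer(enc.buffer)
  doAssert dec.getRequiredField(1, msgType).isOk()
  doAssert dec.getField(2, ns).get(false)
  doAssert msgType == 3 and ns == "my-namespace"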
|
||||
|
||||
type
|
||||
RendezVousError* = object of DiscoveryError
|
||||
RegisteredData = object
|
||||
expiration: Moment
|
||||
peerId: PeerId
|
||||
data: Register
|
||||
|
||||
RendezVous* = ref object of LPProtocol
|
||||
# Registered needs to be an offsetted sequence
|
||||
# because we need stable index for the cookies.
|
||||
registered: OffsettedSeq[RegisteredData]
|
||||
# Namespaces is a table whose key is a salted namespace and
|
||||
# the value is the index sequence corresponding to this
|
||||
# namespace in the offsettedqueue.
|
||||
namespaces: Table[string, seq[int]]
|
||||
rng: ref HmacDrbgContext
|
||||
salt: string
|
||||
defaultDT: Moment
|
||||
registerDeletionLoop: Future[void]
|
||||
#registerEvent: AsyncEvent # TODO: to raise during the heartbeat
|
||||
# + make the heartbeat sleep duration "smarter"
|
||||
sema: AsyncSemaphore
|
||||
peers: seq[PeerId]
|
||||
cookiesSaved: Table[PeerId, Table[string, seq[byte]]]
|
||||
switch: Switch
|
||||
minDuration: Duration
|
||||
maxDuration: Duration
|
||||
minTTL: uint64
|
||||
maxTTL: uint64
|
||||
|
||||
proc checkPeerRecord(spr: seq[byte], peerId: PeerId): Result[void, string] =
|
||||
if spr.len == 0:
|
||||
return err("Empty peer record")
|
||||
let signedEnv = ?SignedPeerRecord.decode(spr).mapErr(x => $x)
|
||||
if signedEnv.data.peerId != peerId:
|
||||
return err("Bad Peer ID")
|
||||
return ok()
|
||||
|
||||
proc sendRegisterResponse(
|
||||
conn: Connection, ttl: uint64
|
||||
) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
let msg = encode(
|
||||
Message(
|
||||
msgType: MessageType.RegisterResponse,
|
||||
registerResponse: Opt.some(RegisterResponse(status: Ok, ttl: Opt.some(ttl))),
|
||||
)
|
||||
)
|
||||
await conn.writeLp(msg.buffer)
|
||||
|
||||
proc sendRegisterResponseError(
|
||||
conn: Connection, status: ResponseStatus, text: string = ""
|
||||
) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
let msg = encode(
|
||||
Message(
|
||||
msgType: MessageType.RegisterResponse,
|
||||
registerResponse: Opt.some(RegisterResponse(status: status, text: Opt.some(text))),
|
||||
)
|
||||
)
|
||||
await conn.writeLp(msg.buffer)
|
||||
|
||||
proc sendDiscoverResponse(
|
||||
conn: Connection, s: seq[Register], cookie: Cookie
|
||||
) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
let msg = encode(
|
||||
Message(
|
||||
msgType: MessageType.DiscoverResponse,
|
||||
discoverResponse: Opt.some(
|
||||
DiscoverResponse(
|
||||
status: Ok, registrations: s, cookie: Opt.some(cookie.encode().buffer)
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
await conn.writeLp(msg.buffer)
|
||||
|
||||
proc sendDiscoverResponseError(
|
||||
conn: Connection, status: ResponseStatus, text: string = ""
|
||||
) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
let msg = encode(
|
||||
Message(
|
||||
msgType: MessageType.DiscoverResponse,
|
||||
discoverResponse: Opt.some(DiscoverResponse(status: status, text: Opt.some(text))),
|
||||
)
|
||||
)
|
||||
await conn.writeLp(msg.buffer)
|
||||
|
||||
proc countRegister(rdv: RendezVous, peerId: PeerId): int =
|
||||
let n = Moment.now()
|
||||
for data in rdv.registered:
|
||||
if data.peerId == peerId and data.expiration > n:
|
||||
result.inc()
|
||||
|
||||
proc save(
|
||||
rdv: RendezVous, ns: string, peerId: PeerId, r: Register, update: bool = true
|
||||
) =
|
||||
let nsSalted = ns & rdv.salt
|
||||
discard rdv.namespaces.hasKeyOrPut(nsSalted, newSeq[int]())
|
||||
try:
|
||||
for index in rdv.namespaces[nsSalted]:
|
||||
if rdv.registered[index].peerId == peerId:
|
||||
if update == false:
|
||||
return
|
||||
rdv.registered[index].expiration = rdv.defaultDT
|
||||
rdv.registered.add(
|
||||
RegisteredData(
|
||||
peerId: peerId,
|
||||
expiration: Moment.now() + r.ttl.get(rdv.minTTL).int64.seconds,
|
||||
data: r,
|
||||
)
|
||||
)
|
||||
rdv.namespaces[nsSalted].add(rdv.registered.high)
|
||||
# rdv.registerEvent.fire()
|
||||
except KeyError as e:
|
||||
doAssert false, "Should have key: " & e.msg
|
||||
|
||||
proc register(rdv: RendezVous, conn: Connection, r: Register): Future[void] =
|
||||
trace "Received Register", peerId = conn.peerId, ns = r.ns
|
||||
libp2p_rendezvous_register.inc()
|
||||
if r.ns.len < MinimumNamespaceLen or r.ns.len > MaximumNamespaceLen:
|
||||
return conn.sendRegisterResponseError(InvalidNamespace)
|
||||
let ttl = r.ttl.get(rdv.minTTL)
|
||||
if ttl < rdv.minTTL or ttl > rdv.maxTTL:
|
||||
return conn.sendRegisterResponseError(InvalidTTL)
|
||||
let pr = checkPeerRecord(r.signedPeerRecord, conn.peerId)
|
||||
if pr.isErr():
|
||||
return conn.sendRegisterResponseError(InvalidSignedPeerRecord, pr.error())
|
||||
if rdv.countRegister(conn.peerId) >= RegistrationLimitPerPeer:
|
||||
return conn.sendRegisterResponseError(NotAuthorized, "Registration limit reached")
|
||||
rdv.save(r.ns, conn.peerId, r)
|
||||
libp2p_rendezvous_registered.inc()
|
||||
libp2p_rendezvous_namespaces.set(int64(rdv.namespaces.len))
|
||||
conn.sendRegisterResponse(ttl)
|
||||
|
||||
proc unregister(rdv: RendezVous, conn: Connection, u: Unregister) =
|
||||
trace "Received Unregister", peerId = conn.peerId, ns = u.ns
|
||||
let nsSalted = u.ns & rdv.salt
|
||||
try:
|
||||
for index in rdv.namespaces[nsSalted]:
|
||||
if rdv.registered[index].peerId == conn.peerId:
|
||||
rdv.registered[index].expiration = rdv.defaultDT
|
||||
libp2p_rendezvous_registered.dec()
|
||||
except KeyError:
|
||||
return
|
||||
|
||||
proc discover(
|
||||
rdv: RendezVous, conn: Connection, d: Discover
|
||||
) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
trace "Received Discover", peerId = conn.peerId, ns = d.ns
|
||||
libp2p_rendezvous_discover.inc()
|
||||
if d.ns.isSome() and d.ns.get().len > MaximumNamespaceLen:
|
||||
await conn.sendDiscoverResponseError(InvalidNamespace)
|
||||
return
|
||||
var limit = min(DiscoverLimit, d.limit.get(DiscoverLimit))
|
||||
var cookie =
|
||||
if d.cookie.isSome():
|
||||
try:
|
||||
Cookie.decode(d.cookie.tryGet()).tryGet()
|
||||
except CatchableError:
|
||||
await conn.sendDiscoverResponseError(InvalidCookie)
|
||||
return
|
||||
else:
|
||||
Cookie(offset: rdv.registered.low().uint64 - 1)
|
||||
if d.ns.isSome() and cookie.ns.isSome() and cookie.ns.get() != d.ns.get() or
|
||||
cookie.offset < rdv.registered.low().uint64 or
|
||||
cookie.offset > rdv.registered.high().uint64:
|
||||
cookie = Cookie(offset: rdv.registered.low().uint64 - 1)
|
||||
let namespaces =
|
||||
if d.ns.isSome():
|
||||
try:
|
||||
rdv.namespaces[d.ns.get() & rdv.salt]
|
||||
except KeyError:
|
||||
await conn.sendDiscoverResponseError(InvalidNamespace)
|
||||
return
|
||||
else:
|
||||
toSeq(max(cookie.offset.int, rdv.registered.offset) .. rdv.registered.high())
|
||||
if namespaces.len() == 0:
|
||||
await conn.sendDiscoverResponse(@[], Cookie())
|
||||
return
|
||||
var offset = namespaces[^1]
|
||||
let n = Moment.now()
|
||||
var s = collect(newSeq()):
|
||||
for index in namespaces:
|
||||
var reg = rdv.registered[index]
|
||||
if limit == 0:
|
||||
offset = index
|
||||
break
|
||||
if reg.expiration < n or index.uint64 <= cookie.offset:
|
||||
continue
|
||||
limit.dec()
|
||||
reg.data.ttl = Opt.some((reg.expiration - Moment.now()).seconds.uint64)
|
||||
reg.data
|
||||
rdv.rng.shuffle(s)
|
||||
await conn.sendDiscoverResponse(s, Cookie(offset: offset.uint64, ns: d.ns))
|
||||
|
||||
proc advertisePeer(
|
||||
rdv: RendezVous, peer: PeerId, msg: seq[byte]
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
proc advertiseWrap() {.async: (raises: []).} =
|
||||
try:
|
||||
let conn = await rdv.switch.dial(peer, RendezVousCodec)
|
||||
defer:
|
||||
await conn.close()
|
||||
await conn.writeLp(msg)
|
||||
let
|
||||
buf = await conn.readLp(4096)
|
||||
msgRecv = Message.decode(buf).tryGet()
|
||||
if msgRecv.msgType != MessageType.RegisterResponse:
|
||||
trace "Unexpected register response", peer, msgType = msgRecv.msgType
|
||||
elif msgRecv.registerResponse.tryGet().status != ResponseStatus.Ok:
|
||||
trace "Refuse to register", peer, response = msgRecv.registerResponse
|
||||
else:
|
||||
trace "Successfully registered", peer, response = msgRecv.registerResponse
|
||||
except CatchableError as exc:
|
||||
trace "exception in the advertise", description = exc.msg
|
||||
finally:
|
||||
rdv.sema.release()
|
||||
|
||||
await rdv.sema.acquire()
|
||||
await advertiseWrap()
|
||||
|
||||
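# Minimal sketch (illustrative, not part of the original file) of the
# semaphore-bounded concurrency used by advertisePeer above: at most
# SemaphoreDefaultSize advertisements are in flight at once. Assumes chronos and
# the AsyncSemaphore helper already imported by this module.
proc advertiseMany(
    sema: AsyncSemaphore, peerCount: int
) {.async: (raises: [CancelledError]).} =
  proc one() {.async: (raises: [CancelledError]).} =
    await sema.acquire()
    try:
      await sleepAsync(10.milliseconds) # stand-in for the dial + register round trip
    finally:
      sema.release()

  var futs: seq[Future[void]]
  for _ in 0 ..< peerCount:
    futs.add(one())
  await allFutures(futs)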
proc advertise*(
|
||||
rdv: RendezVous, ns: string, ttl: Duration, peers: seq[PeerId]
|
||||
) {.async: (raises: [CancelledError, AdvertiseError]).} =
|
||||
if ns.len < MinimumNamespaceLen or ns.len > MaximumNamespaceLen:
|
||||
raise newException(AdvertiseError, "Invalid namespace")
|
||||
|
||||
if ttl < rdv.minDuration or ttl > rdv.maxDuration:
|
||||
raise newException(AdvertiseError, "Invalid time to live: " & $ttl)
|
||||
|
||||
let sprBuff = rdv.switch.peerInfo.signedPeerRecord.encode().valueOr:
|
||||
raise newException(AdvertiseError, "Wrong Signed Peer Record")
|
||||
|
||||
let
|
||||
r = Register(ns: ns, signedPeerRecord: sprBuff, ttl: Opt.some(ttl.seconds.uint64))
|
||||
msg = encode(Message(msgType: MessageType.Register, register: Opt.some(r)))
|
||||
|
||||
rdv.save(ns, rdv.switch.peerInfo.peerId, r)
|
||||
|
||||
let futs = collect(newSeq()):
|
||||
for peer in peers:
|
||||
trace "Send Advertise", peerId = peer, ns
|
||||
rdv.advertisePeer(peer, msg.buffer).withTimeout(5.seconds)
|
||||
|
||||
await allFutures(futs)
|
||||
|
||||
method advertise*(
|
||||
rdv: RendezVous, ns: string, ttl: Duration = rdv.minDuration
|
||||
) {.base, async: (raises: [CancelledError, AdvertiseError]).} =
|
||||
await rdv.advertise(ns, ttl, rdv.peers)
|
||||
|
||||
proc requestLocally*(rdv: RendezVous, ns: string): seq[PeerRecord] =
|
||||
let
|
||||
nsSalted = ns & rdv.salt
|
||||
n = Moment.now()
|
||||
try:
|
||||
collect(newSeq()):
|
||||
for index in rdv.namespaces[nsSalted]:
|
||||
if rdv.registered[index].expiration > n:
|
||||
let res = SignedPeerRecord.decode(rdv.registered[index].data.signedPeerRecord).valueOr:
|
||||
continue
|
||||
res.data
|
||||
except KeyError:
|
||||
@[]
|
||||
|
||||
proc request*(
|
||||
rdv: RendezVous, ns: Opt[string], l: int = DiscoverLimit.int, peers: seq[PeerId]
|
||||
): Future[seq[PeerRecord]] {.async: (raises: [DiscoveryError, CancelledError]).} =
|
||||
var
|
||||
s: Table[PeerId, (PeerRecord, Register)]
|
||||
limit: uint64
|
||||
d = Discover(ns: ns)
|
||||
|
||||
if l <= 0 or l > DiscoverLimit.int:
|
||||
raise newException(AdvertiseError, "Invalid limit")
|
||||
if ns.isSome() and ns.get().len > MaximumNamespaceLen:
|
||||
raise newException(AdvertiseError, "Invalid namespace")
|
||||
|
||||
limit = l.uint64
|
||||
proc requestPeer(
|
||||
peer: PeerId
|
||||
) {.async: (raises: [CancelledError, DialFailedError, LPStreamError]).} =
|
||||
let conn = await rdv.switch.dial(peer, RendezVousCodec)
|
||||
defer:
|
||||
await conn.close()
|
||||
d.limit = Opt.some(limit)
|
||||
d.cookie =
|
||||
if ns.isSome():
|
||||
try:
|
||||
Opt.some(rdv.cookiesSaved[peer][ns.get()])
|
||||
except KeyError, CatchableError:
|
||||
Opt.none(seq[byte])
|
||||
else:
|
||||
Opt.none(seq[byte])
|
||||
await conn.writeLp(
|
||||
encode(Message(msgType: MessageType.Discover, discover: Opt.some(d))).buffer
|
||||
)
|
||||
let
|
||||
buf = await conn.readLp(MaximumMessageLen)
|
||||
msgRcv = Message.decode(buf).valueOr:
|
||||
debug "Message undecodable"
|
||||
return
|
||||
if msgRcv.msgType != MessageType.DiscoverResponse:
|
||||
debug "Unexpected discover response", msgType = msgRcv.msgType
|
||||
return
|
||||
let resp = msgRcv.discoverResponse.valueOr:
|
||||
debug "Discover response is empty"
|
||||
return
|
||||
if resp.status != ResponseStatus.Ok:
|
||||
trace "Cannot discover", ns, status = resp.status, text = resp.text
|
||||
return
|
||||
resp.cookie.withValue(cookie):
|
||||
if ns.isSome:
|
||||
let namespace = ns.get()
|
||||
if cookie.len() < 1000 and
|
||||
rdv.cookiesSaved.hasKeyOrPut(peer, {namespace: cookie}.toTable()):
|
||||
try:
|
||||
rdv.cookiesSaved[peer][namespace] = cookie
|
||||
except KeyError:
|
||||
raiseAssert "checked with hasKeyOrPut"
|
||||
for r in resp.registrations:
|
||||
if limit == 0:
|
||||
return
|
||||
let ttl = r.ttl.get(rdv.maxTTL + 1)
|
||||
if ttl > rdv.maxTTL:
|
||||
continue
|
||||
let
|
||||
spr = SignedPeerRecord.decode(r.signedPeerRecord).valueOr:
|
||||
continue
|
||||
pr = spr.data
|
||||
if s.hasKey(pr.peerId):
|
||||
let (prSaved, rSaved) =
|
||||
try:
|
||||
s[pr.peerId]
|
||||
except KeyError:
|
||||
raiseAssert "checked with hasKey"
|
||||
if (prSaved.seqNo == pr.seqNo and rSaved.ttl.get(rdv.maxTTL) < ttl) or
|
||||
prSaved.seqNo < pr.seqNo:
|
||||
s[pr.peerId] = (pr, r)
|
||||
else:
|
||||
s[pr.peerId] = (pr, r)
|
||||
limit.dec()
|
||||
if ns.isSome():
|
||||
for (_, r) in s.values():
|
||||
rdv.save(ns.get(), peer, r, false)
|
||||
|
||||
for peer in peers:
|
||||
if limit == 0:
|
||||
break
|
||||
if RendezVousCodec notin rdv.switch.peerStore[ProtoBook][peer]:
|
||||
continue
|
||||
try:
|
||||
trace "Send Request", peerId = peer, ns
|
||||
await peer.requestPeer()
|
||||
except CancelledError as e:
|
||||
raise e
|
||||
except DialFailedError as e:
|
||||
trace "failed to dial a peer", description = e.msg
|
||||
except LPStreamError as e:
|
||||
trace "failed to communicate with a peer", description = e.msg
|
||||
return toSeq(s.values()).mapIt(it[0])
|
||||
|
||||
proc request*(
|
||||
rdv: RendezVous, ns: Opt[string], l: int = DiscoverLimit.int
|
||||
): Future[seq[PeerRecord]] {.async: (raises: [DiscoveryError, CancelledError]).} =
|
||||
await rdv.request(ns, l, rdv.peers)
|
||||
|
||||
proc request*(
|
||||
rdv: RendezVous, l: int = DiscoverLimit.int
|
||||
): Future[seq[PeerRecord]] {.async: (raises: [DiscoveryError, CancelledError]).} =
|
||||
await rdv.request(Opt.none(string), l, rdv.peers)
|
||||
|
||||
proc unsubscribeLocally*(rdv: RendezVous, ns: string) =
|
||||
let nsSalted = ns & rdv.salt
|
||||
try:
|
||||
for index in rdv.namespaces[nsSalted]:
|
||||
if rdv.registered[index].peerId == rdv.switch.peerInfo.peerId:
|
||||
rdv.registered[index].expiration = rdv.defaultDT
|
||||
except KeyError:
|
||||
return
|
||||
|
||||
proc unsubscribe*(
|
||||
rdv: RendezVous, ns: string, peerIds: seq[PeerId]
|
||||
) {.async: (raises: [RendezVousError, CancelledError]).} =
|
||||
if ns.len < MinimumNamespaceLen or ns.len > MaximumNamespaceLen:
|
||||
raise newException(RendezVousError, "Invalid namespace")
|
||||
|
||||
let msg = encode(
|
||||
Message(msgType: MessageType.Unregister, unregister: Opt.some(Unregister(ns: ns)))
|
||||
)
|
||||
|
||||
proc unsubscribePeer(peerId: PeerId) {.async: (raises: []).} =
|
||||
try:
|
||||
let conn = await rdv.switch.dial(peerId, RendezVousCodec)
|
||||
defer:
|
||||
await conn.close()
|
||||
await conn.writeLp(msg.buffer)
|
||||
except CatchableError as exc:
|
||||
trace "exception while unsubscribing", description = exc.msg
|
||||
|
||||
let futs = collect(newSeq()):
|
||||
for peer in peerIds:
|
||||
unsubscribePeer(peer)
|
||||
|
||||
await allFutures(futs)
|
||||
|
||||
proc unsubscribe*(
|
||||
rdv: RendezVous, ns: string
|
||||
) {.async: (raises: [RendezVousError, CancelledError]).} =
|
||||
rdv.unsubscribeLocally(ns)
|
||||
|
||||
await rdv.unsubscribe(ns, rdv.peers)
|
||||
|
||||
proc setup*(rdv: RendezVous, switch: Switch) =
|
||||
rdv.switch = switch
|
||||
proc handlePeer(
|
||||
peerId: PeerId, event: PeerEvent
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
if event.kind == PeerEventKind.Joined:
|
||||
rdv.peers.add(peerId)
|
||||
elif event.kind == PeerEventKind.Left:
|
||||
rdv.peers.keepItIf(it != peerId)
|
||||
|
||||
rdv.switch.addPeerEventHandler(handlePeer, Joined)
|
||||
rdv.switch.addPeerEventHandler(handlePeer, Left)
|
||||
|
||||
proc new*(
|
||||
T: typedesc[RendezVous],
|
||||
rng: ref HmacDrbgContext = newRng(),
|
||||
minDuration = MinimumDuration,
|
||||
maxDuration = MaximumDuration,
|
||||
): T {.raises: [RendezVousError].} =
|
||||
if minDuration < 1.minutes:
|
||||
raise newException(RendezVousError, "TTL too short: 1 minute minimum")
|
||||
|
||||
if maxDuration > 72.hours:
|
||||
raise newException(RendezVousError, "TTL too long: 72 hours maximum")
|
||||
|
||||
if minDuration >= maxDuration:
|
||||
raise newException(RendezVousError, "Minimum TTL longer than maximum")
|
||||
|
||||
let
|
||||
minTTL = minDuration.seconds.uint64
|
||||
maxTTL = maxDuration.seconds.uint64
|
||||
|
||||
let rdv = T(
|
||||
rng: rng,
|
||||
salt: string.fromBytes(generateBytes(rng[], 8)),
|
||||
registered: initOffsettedSeq[RegisteredData](1),
|
||||
defaultDT: Moment.now() - 1.days,
|
||||
#registerEvent: newAsyncEvent(),
|
||||
sema: newAsyncSemaphore(SemaphoreDefaultSize),
|
||||
minDuration: minDuration,
|
||||
maxDuration: maxDuration,
|
||||
minTTL: minTTL,
|
||||
maxTTL: maxTTL,
|
||||
)
|
||||
logScope:
|
||||
topics = "libp2p discovery rendezvous"
|
||||
proc handleStream(
|
||||
conn: Connection, proto: string
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
try:
|
||||
let
|
||||
buf = await conn.readLp(4096)
|
||||
msg = Message.decode(buf).tryGet()
|
||||
case msg.msgType
|
||||
of MessageType.Register:
|
||||
await rdv.register(conn, msg.register.tryGet())
|
||||
of MessageType.RegisterResponse:
|
||||
trace "Got an unexpected Register Response", response = msg.registerResponse
|
||||
of MessageType.Unregister:
|
||||
rdv.unregister(conn, msg.unregister.tryGet())
|
||||
of MessageType.Discover:
|
||||
await rdv.discover(conn, msg.discover.tryGet())
|
||||
of MessageType.DiscoverResponse:
|
||||
trace "Got an unexpected Discover Response", response = msg.discoverResponse
|
||||
except CancelledError as exc:
|
||||
trace "cancelled rendezvous handler"
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
trace "exception in rendezvous handler", description = exc.msg
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
rdv.handler = handleStream
|
||||
rdv.codec = RendezVousCodec
|
||||
return rdv
|
||||
|
||||
proc new*(
|
||||
T: typedesc[RendezVous],
|
||||
switch: Switch,
|
||||
rng: ref HmacDrbgContext = newRng(),
|
||||
minDuration = MinimumDuration,
|
||||
maxDuration = MaximumDuration,
|
||||
): T {.raises: [RendezVousError].} =
|
||||
let rdv = T.new(rng, minDuration, maxDuration)
|
||||
rdv.setup(switch)
|
||||
return rdv
|
||||
|
||||
proc deletesRegister(
|
||||
rdv: RendezVous, interval = 1.minutes
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
heartbeat "Register timeout", interval:
|
||||
let n = Moment.now()
|
||||
var total = 0
|
||||
rdv.registered.flushIfIt(it.expiration < n)
|
||||
for data in rdv.namespaces.mvalues():
|
||||
data.keepItIf(it >= rdv.registered.offset)
|
||||
total += data.len
|
||||
libp2p_rendezvous_registered.set(int64(total))
|
||||
libp2p_rendezvous_namespaces.set(int64(rdv.namespaces.len))
|
||||
|
||||
method start*(
|
||||
rdv: RendezVous
|
||||
): Future[void] {.async: (raises: [CancelledError], raw: true).} =
|
||||
let fut = newFuture[void]()
|
||||
fut.complete()
|
||||
if not rdv.registerDeletionLoop.isNil:
|
||||
warn "Starting rendezvous twice"
|
||||
return fut
|
||||
rdv.registerDeletionLoop = rdv.deletesRegister()
|
||||
rdv.started = true
|
||||
fut
|
||||
|
||||
method stop*(rdv: RendezVous): Future[void] {.async: (raises: [], raw: true).} =
|
||||
let fut = newFuture[void]()
|
||||
fut.complete()
|
||||
if rdv.registerDeletionLoop.isNil:
|
||||
warn "Stopping rendezvous without starting it"
|
||||
return fut
|
||||
rdv.started = false
|
||||
rdv.registerDeletionLoop.cancel()
|
||||
rdv.registerDeletionLoop = nil
|
||||
fut
|
||||
export rendezvous
|
||||
|
||||
libp2p/protocols/rendezvous/protobuf.nim (new file, 275 lines)
@@ -0,0 +1,275 @@
|
||||
# Nim-LibP2P
|
||||
# Copyright (c) 2023-2024 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import results
|
||||
import stew/objects
|
||||
import ../../protobuf/minprotobuf
|
||||
|
||||
type
|
||||
MessageType* {.pure.} = enum
|
||||
Register = 0
|
||||
RegisterResponse = 1
|
||||
Unregister = 2
|
||||
Discover = 3
|
||||
DiscoverResponse = 4
|
||||
|
||||
ResponseStatus* = enum
|
||||
Ok = 0
|
||||
InvalidNamespace = 100
|
||||
InvalidSignedPeerRecord = 101
|
||||
InvalidTTL = 102
|
||||
InvalidCookie = 103
|
||||
NotAuthorized = 200
|
||||
InternalError = 300
|
||||
Unavailable = 400
|
||||
|
||||
Cookie* = object
|
||||
offset*: uint64
|
||||
ns*: Opt[string]
|
||||
|
||||
Register* = object
|
||||
ns*: string
|
||||
signedPeerRecord*: seq[byte]
|
||||
ttl*: Opt[uint64] # in seconds
|
||||
|
||||
RegisterResponse* = object
|
||||
status*: ResponseStatus
|
||||
text*: Opt[string]
|
||||
ttl*: Opt[uint64] # in seconds
|
||||
|
||||
Unregister* = object
|
||||
ns*: string
|
||||
|
||||
Discover* = object
|
||||
ns*: Opt[string]
|
||||
limit*: Opt[uint64]
|
||||
cookie*: Opt[seq[byte]]
|
||||
|
||||
DiscoverResponse* = object
|
||||
registrations*: seq[Register]
|
||||
cookie*: Opt[seq[byte]]
|
||||
status*: ResponseStatus
|
||||
text*: Opt[string]
|
||||
|
||||
Message* = object
|
||||
msgType*: MessageType
|
||||
register*: Opt[Register]
|
||||
registerResponse*: Opt[RegisterResponse]
|
||||
unregister*: Opt[Unregister]
|
||||
discover*: Opt[Discover]
|
||||
discoverResponse*: Opt[DiscoverResponse]
|
||||
|
||||
proc encode*(c: Cookie): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
result.write(1, c.offset)
|
||||
if c.ns.isSome():
|
||||
result.write(2, c.ns.get())
|
||||
result.finish()
|
||||
|
||||
proc encode*(r: Register): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
result.write(1, r.ns)
|
||||
result.write(2, r.signedPeerRecord)
|
||||
r.ttl.withValue(ttl):
|
||||
result.write(3, ttl)
|
||||
result.finish()
|
||||
|
||||
proc encode*(rr: RegisterResponse): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
result.write(1, rr.status.uint)
|
||||
rr.text.withValue(text):
|
||||
result.write(2, text)
|
||||
rr.ttl.withValue(ttl):
|
||||
result.write(3, ttl)
|
||||
result.finish()
|
||||
|
||||
proc encode*(u: Unregister): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
result.write(1, u.ns)
|
||||
result.finish()
|
||||
|
||||
proc encode*(d: Discover): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
if d.ns.isSome():
|
||||
result.write(1, d.ns.get())
|
||||
d.limit.withValue(limit):
|
||||
result.write(2, limit)
|
||||
d.cookie.withValue(cookie):
|
||||
result.write(3, cookie)
|
||||
result.finish()
|
||||
|
||||
proc encode*(dr: DiscoverResponse): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
for reg in dr.registrations:
|
||||
result.write(1, reg.encode())
|
||||
dr.cookie.withValue(cookie):
|
||||
result.write(2, cookie)
|
||||
result.write(3, dr.status.uint)
|
||||
dr.text.withValue(text):
|
||||
result.write(4, text)
|
||||
result.finish()
|
||||
|
||||
proc encode*(msg: Message): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
result.write(1, msg.msgType.uint)
|
||||
msg.register.withValue(register):
|
||||
result.write(2, register.encode())
|
||||
msg.registerResponse.withValue(registerResponse):
|
||||
result.write(3, registerResponse.encode())
|
||||
msg.unregister.withValue(unregister):
|
||||
result.write(4, unregister.encode())
|
||||
msg.discover.withValue(discover):
|
||||
result.write(5, discover.encode())
|
||||
msg.discoverResponse.withValue(discoverResponse):
|
||||
result.write(6, discoverResponse.encode())
|
||||
result.finish()
|
||||
|
||||
proc decode*(_: typedesc[Cookie], buf: seq[byte]): Opt[Cookie] =
|
||||
var
|
||||
c: Cookie
|
||||
ns: string
|
||||
let
|
||||
pb = initProtoBuffer(buf)
|
||||
r1 = pb.getRequiredField(1, c.offset)
|
||||
r2 = pb.getField(2, ns)
|
||||
if r1.isErr() or r2.isErr():
|
||||
return Opt.none(Cookie)
|
||||
if r2.get(false):
|
||||
c.ns = Opt.some(ns)
|
||||
Opt.some(c)
|
||||
|
||||
proc decode*(_: typedesc[Register], buf: seq[byte]): Opt[Register] =
|
||||
var
|
||||
r: Register
|
||||
ttl: uint64
|
||||
let
|
||||
pb = initProtoBuffer(buf)
|
||||
r1 = pb.getRequiredField(1, r.ns)
|
||||
r2 = pb.getRequiredField(2, r.signedPeerRecord)
|
||||
r3 = pb.getField(3, ttl)
|
||||
if r1.isErr() or r2.isErr() or r3.isErr():
|
||||
return Opt.none(Register)
|
||||
if r3.get(false):
|
||||
r.ttl = Opt.some(ttl)
|
||||
Opt.some(r)
|
||||
|
||||
proc decode*(_: typedesc[RegisterResponse], buf: seq[byte]): Opt[RegisterResponse] =
|
||||
var
|
||||
rr: RegisterResponse
|
||||
statusOrd: uint
|
||||
text: string
|
||||
ttl: uint64
|
||||
let
|
||||
pb = initProtoBuffer(buf)
|
||||
r1 = pb.getRequiredField(1, statusOrd)
|
||||
r2 = pb.getField(2, text)
|
||||
r3 = pb.getField(3, ttl)
|
||||
if r1.isErr() or r2.isErr() or r3.isErr() or
|
||||
not checkedEnumAssign(rr.status, statusOrd):
|
||||
return Opt.none(RegisterResponse)
|
||||
if r2.get(false):
|
||||
rr.text = Opt.some(text)
|
||||
if r3.get(false):
|
||||
rr.ttl = Opt.some(ttl)
|
||||
Opt.some(rr)
|
||||
|
||||
proc decode*(_: typedesc[Unregister], buf: seq[byte]): Opt[Unregister] =
|
||||
var u: Unregister
|
||||
let
|
||||
pb = initProtoBuffer(buf)
|
||||
r1 = pb.getRequiredField(1, u.ns)
|
||||
if r1.isErr():
|
||||
return Opt.none(Unregister)
|
||||
Opt.some(u)
|
||||
|
||||
proc decode*(_: typedesc[Discover], buf: seq[byte]): Opt[Discover] =
|
||||
var
|
||||
d: Discover
|
||||
limit: uint64
|
||||
cookie: seq[byte]
|
||||
ns: string
|
||||
let
|
||||
pb = initProtoBuffer(buf)
|
||||
r1 = pb.getField(1, ns)
|
||||
r2 = pb.getField(2, limit)
|
||||
r3 = pb.getField(3, cookie)
|
||||
if r1.isErr() or r2.isErr() or r3.isErr:
|
||||
return Opt.none(Discover)
|
||||
if r1.get(false):
|
||||
d.ns = Opt.some(ns)
|
||||
if r2.get(false):
|
||||
d.limit = Opt.some(limit)
|
||||
if r3.get(false):
|
||||
d.cookie = Opt.some(cookie)
|
||||
Opt.some(d)
|
||||
|
||||
proc decode*(_: typedesc[DiscoverResponse], buf: seq[byte]): Opt[DiscoverResponse] =
|
||||
var
|
||||
dr: DiscoverResponse
|
||||
registrations: seq[seq[byte]]
|
||||
cookie: seq[byte]
|
||||
statusOrd: uint
|
||||
text: string
|
||||
let
|
||||
pb = initProtoBuffer(buf)
|
||||
r1 = pb.getRepeatedField(1, registrations)
|
||||
r2 = pb.getField(2, cookie)
|
||||
r3 = pb.getRequiredField(3, statusOrd)
|
||||
r4 = pb.getField(4, text)
|
||||
if r1.isErr() or r2.isErr() or r3.isErr or r4.isErr() or
|
||||
not checkedEnumAssign(dr.status, statusOrd):
|
||||
return Opt.none(DiscoverResponse)
|
||||
for reg in registrations:
|
||||
let regOpt = Register.decode(reg).valueOr:
|
||||
return
|
||||
dr.registrations.add(regOpt)
|
||||
if r2.get(false):
|
||||
dr.cookie = Opt.some(cookie)
|
||||
if r4.get(false):
|
||||
dr.text = Opt.some(text)
|
||||
Opt.some(dr)
|
||||
|
||||
proc decode*(_: typedesc[Message], buf: seq[byte]): Opt[Message] =
|
||||
var
|
||||
msg: Message
|
||||
statusOrd: uint
|
||||
pbr, pbrr, pbu, pbd, pbdr: ProtoBuffer
|
||||
let pb = initProtoBuffer(buf)
|
||||
|
||||
?pb.getRequiredField(1, statusOrd).toOpt
|
||||
if not checkedEnumAssign(msg.msgType, statusOrd):
|
||||
return Opt.none(Message)
|
||||
|
||||
if ?pb.getField(2, pbr).optValue:
|
||||
msg.register = Register.decode(pbr.buffer)
|
||||
if msg.register.isNone():
|
||||
return Opt.none(Message)
|
||||
|
||||
if ?pb.getField(3, pbrr).optValue:
|
||||
msg.registerResponse = RegisterResponse.decode(pbrr.buffer)
|
||||
if msg.registerResponse.isNone():
|
||||
return Opt.none(Message)
|
||||
|
||||
if ?pb.getField(4, pbu).optValue:
|
||||
msg.unregister = Unregister.decode(pbu.buffer)
|
||||
if msg.unregister.isNone():
|
||||
return Opt.none(Message)
|
||||
|
||||
if ?pb.getField(5, pbd).optValue:
|
||||
msg.discover = Discover.decode(pbd.buffer)
|
||||
if msg.discover.isNone():
|
||||
return Opt.none(Message)
|
||||
|
||||
if ?pb.getField(6, pbdr).optValue:
|
||||
msg.discoverResponse = DiscoverResponse.decode(pbdr.buffer)
|
||||
if msg.discoverResponse.isNone():
|
||||
return Opt.none(Message)
|
||||
|
||||
Opt.some(msg)
|
||||
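# Hedged usage sketch (not part of the original file): round-tripping a Register
# through the encoders and decoders defined above. The signed peer record bytes
# are placeholders; a real SignedPeerRecord envelope is expected on the wire.
when isMainModule:
  let reg = Register(
    ns: "example/namespace",
    signedPeerRecord: @[1'u8, 2, 3],
    ttl: Opt.some(7200'u64),
  )
  let msg = Message(msgType: MessageType.Register, register: Opt.some(reg))
  let decoded = Message.decode(encode(msg).buffer)
  assert decoded.isSome()
  assert decoded.get().register.get().ns == "example/namespace"
  assert decoded.get().register.get().ttl.get() == 7200'u64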
libp2p/protocols/rendezvous/rendezvous.nim (new file, 589 lines)
@@ -0,0 +1,589 @@
|
||||
# Nim-LibP2P
|
||||
# Copyright (c) 2023-2024 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
{.push raises: [].}
|
||||
|
||||
import tables, sequtils, sugar, sets
|
||||
import metrics except collect
|
||||
import chronos, chronicles, bearssl/rand, stew/[byteutils, objects]
|
||||
import
|
||||
./protobuf,
|
||||
../protocol,
|
||||
../../protobuf/minprotobuf,
|
||||
../../switch,
|
||||
../../routing_record,
|
||||
../../utils/heartbeat,
|
||||
../../stream/connection,
|
||||
../../utils/offsettedseq,
|
||||
../../utils/semaphore,
|
||||
../../discovery/discoverymngr
|
||||
|
||||
export chronicles
|
||||
|
||||
logScope:
|
||||
topics = "libp2p discovery rendezvous"
|
||||
|
||||
declareCounter(libp2p_rendezvous_register, "number of advertise requests")
|
||||
declareCounter(libp2p_rendezvous_discover, "number of discovery requests")
|
||||
declareGauge(libp2p_rendezvous_registered, "number of registered peers")
|
||||
declareGauge(libp2p_rendezvous_namespaces, "number of registered namespaces")
|
||||
|
||||
const
|
||||
RendezVousCodec* = "/rendezvous/1.0.0"
|
||||
# Default minimum TTL per libp2p spec
|
||||
MinimumDuration* = 2.hours
|
||||
# Lower validation limit to accommodate Waku requirements
|
||||
MinimumAcceptedDuration = 1.minutes
|
||||
MaximumDuration = 72.hours
|
||||
MaximumMessageLen = 1 shl 22 # 4MB
|
||||
MinimumNamespaceLen = 1
|
||||
MaximumNamespaceLen = 255
|
||||
RegistrationLimitPerPeer* = 1000
|
||||
DiscoverLimit = 1000'u64
|
||||
SemaphoreDefaultSize = 5
|
||||
|
||||
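# Exposition-only check (not part of the original file) of how the duration
# bounds above map to the TTL seconds exchanged on the wire:
when isMainModule:
  doAssert MinimumDuration.seconds.uint64 == 7200'u64         # 2 hours
  doAssert MinimumAcceptedDuration.seconds.uint64 == 60'u64   # 1 minute
  doAssert MaximumDuration.seconds.uint64 == 259_200'u64      # 72 hours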
type
|
||||
RendezVousError* = object of DiscoveryError
|
||||
RegisteredData = object
|
||||
expiration*: Moment
|
||||
peerId*: PeerId
|
||||
data*: Register
|
||||
|
||||
RendezVous* = ref object of LPProtocol
|
||||
# Registered needs to be an offsetted sequence
|
||||
# because we need stable index for the cookies.
|
||||
registered*: OffsettedSeq[RegisteredData]
|
||||
# Namespaces is a table whose key is a salted namespace and
|
||||
# the value is the index sequence corresponding to this
|
||||
# namespace in the offsetted sequence.
|
||||
namespaces*: Table[string, seq[int]]
|
||||
rng: ref HmacDrbgContext
|
||||
salt: string
|
||||
expiredDT: Moment
|
||||
registerDeletionLoop: Future[void]
|
||||
#registerEvent: AsyncEvent # TODO: to raise during the heartbeat
|
||||
# + make the heartbeat sleep duration "smarter"
|
||||
sema: AsyncSemaphore
|
||||
peers: seq[PeerId]
|
||||
cookiesSaved*: Table[PeerId, Table[string, seq[byte]]]
|
||||
switch: Switch
|
||||
minDuration: Duration
|
||||
maxDuration: Duration
|
||||
minTTL: uint64
|
||||
maxTTL: uint64
|
||||
|
||||
proc checkPeerRecord(spr: seq[byte], peerId: PeerId): Result[void, string] =
|
||||
if spr.len == 0:
|
||||
return err("Empty peer record")
|
||||
let signedEnv = ?SignedPeerRecord.decode(spr).mapErr(x => $x)
|
||||
if signedEnv.data.peerId != peerId:
|
||||
return err("Bad Peer ID")
|
||||
return ok()
|
||||
|
||||
proc sendRegisterResponse(
|
||||
conn: Connection, ttl: uint64
|
||||
) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
let msg = encode(
|
||||
Message(
|
||||
msgType: MessageType.RegisterResponse,
|
||||
registerResponse: Opt.some(RegisterResponse(status: Ok, ttl: Opt.some(ttl))),
|
||||
)
|
||||
)
|
||||
await conn.writeLp(msg.buffer)
|
||||
|
||||
proc sendRegisterResponseError(
|
||||
conn: Connection, status: ResponseStatus, text: string = ""
|
||||
) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
let msg = encode(
|
||||
Message(
|
||||
msgType: MessageType.RegisterResponse,
|
||||
registerResponse: Opt.some(RegisterResponse(status: status, text: Opt.some(text))),
|
||||
)
|
||||
)
|
||||
await conn.writeLp(msg.buffer)
|
||||
|
||||
proc sendDiscoverResponse(
|
||||
conn: Connection, s: seq[Register], cookie: Cookie
|
||||
) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
let msg = encode(
|
||||
Message(
|
||||
msgType: MessageType.DiscoverResponse,
|
||||
discoverResponse: Opt.some(
|
||||
DiscoverResponse(
|
||||
status: Ok, registrations: s, cookie: Opt.some(cookie.encode().buffer)
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
await conn.writeLp(msg.buffer)
|
||||
|
||||
proc sendDiscoverResponseError(
|
||||
conn: Connection, status: ResponseStatus, text: string = ""
|
||||
) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
let msg = encode(
|
||||
Message(
|
||||
msgType: MessageType.DiscoverResponse,
|
||||
discoverResponse: Opt.some(DiscoverResponse(status: status, text: Opt.some(text))),
|
||||
)
|
||||
)
|
||||
await conn.writeLp(msg.buffer)
|
||||
|
||||
proc countRegister(rdv: RendezVous, peerId: PeerId): int =
|
||||
for data in rdv.registered:
|
||||
if data.peerId == peerId:
|
||||
result.inc()
|
||||
|
||||
proc save(
|
||||
rdv: RendezVous, ns: string, peerId: PeerId, r: Register, update: bool = true
|
||||
) =
|
||||
let nsSalted = ns & rdv.salt
|
||||
discard rdv.namespaces.hasKeyOrPut(nsSalted, newSeq[int]())
|
||||
try:
|
||||
for index in rdv.namespaces[nsSalted]:
|
||||
if rdv.registered[index].peerId == peerId:
|
||||
if update == false:
|
||||
return
|
||||
rdv.registered[index].expiration = rdv.expiredDT
|
||||
rdv.registered.add(
|
||||
RegisteredData(
|
||||
peerId: peerId,
|
||||
expiration: Moment.now() + r.ttl.get(rdv.minTTL).int64.seconds,
|
||||
data: r,
|
||||
)
|
||||
)
|
||||
rdv.namespaces[nsSalted].add(rdv.registered.high)
|
||||
# rdv.registerEvent.fire()
|
||||
except KeyError as e:
|
||||
doAssert false, "Should have key: " & e.msg
|
||||
|
||||
proc register(rdv: RendezVous, conn: Connection, r: Register): Future[void] =
|
||||
trace "Received Register", peerId = conn.peerId, ns = r.ns
|
||||
libp2p_rendezvous_register.inc()
|
||||
if r.ns.len < MinimumNamespaceLen or r.ns.len > MaximumNamespaceLen:
|
||||
return conn.sendRegisterResponseError(InvalidNamespace)
|
||||
let ttl = r.ttl.get(rdv.minTTL)
|
||||
if ttl < rdv.minTTL or ttl > rdv.maxTTL:
|
||||
return conn.sendRegisterResponseError(InvalidTTL)
|
||||
let pr = checkPeerRecord(r.signedPeerRecord, conn.peerId)
|
||||
if pr.isErr():
|
||||
return conn.sendRegisterResponseError(InvalidSignedPeerRecord, pr.error())
|
||||
if rdv.countRegister(conn.peerId) >= RegistrationLimitPerPeer:
|
||||
return conn.sendRegisterResponseError(NotAuthorized, "Registration limit reached")
|
||||
rdv.save(r.ns, conn.peerId, r)
|
||||
libp2p_rendezvous_registered.inc()
|
||||
libp2p_rendezvous_namespaces.set(int64(rdv.namespaces.len))
|
||||
conn.sendRegisterResponse(ttl)
|
||||
|
||||
proc unregister(rdv: RendezVous, conn: Connection, u: Unregister) =
|
||||
trace "Received Unregister", peerId = conn.peerId, ns = u.ns
|
||||
let nsSalted = u.ns & rdv.salt
|
||||
try:
|
||||
for index in rdv.namespaces[nsSalted]:
|
||||
if rdv.registered[index].peerId == conn.peerId:
|
||||
rdv.registered[index].expiration = rdv.expiredDT
|
||||
libp2p_rendezvous_registered.dec()
|
||||
except KeyError:
|
||||
return
|
||||
|
||||
proc discover(
|
||||
rdv: RendezVous, conn: Connection, d: Discover
|
||||
) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
trace "Received Discover", peerId = conn.peerId, ns = d.ns
|
||||
libp2p_rendezvous_discover.inc()
|
||||
if d.ns.isSome() and d.ns.get().len > MaximumNamespaceLen:
|
||||
await conn.sendDiscoverResponseError(InvalidNamespace)
|
||||
return
|
||||
var limit = min(DiscoverLimit, d.limit.get(DiscoverLimit))
|
||||
var cookie =
|
||||
if d.cookie.isSome():
|
||||
try:
|
||||
Cookie.decode(d.cookie.tryGet()).tryGet()
|
||||
except CatchableError:
|
||||
await conn.sendDiscoverResponseError(InvalidCookie)
|
||||
return
|
||||
else:
|
||||
# Start from the current lowest index (inclusive)
|
||||
Cookie(offset: rdv.registered.low().uint64)
|
||||
if d.ns.isSome() and cookie.ns.isSome() and cookie.ns.get() != d.ns.get():
|
||||
# Namespace changed: start from the beginning of that namespace
|
||||
cookie = Cookie(offset: rdv.registered.low().uint64)
|
||||
elif cookie.offset < rdv.registered.low().uint64:
|
||||
# Cookie behind available range: reset to current low
|
||||
cookie.offset = rdv.registered.low().uint64
|
||||
elif cookie.offset > (rdv.registered.high() + 1).uint64:
|
||||
# Cookie ahead of available range: reset to one past current high (empty page)
|
||||
cookie.offset = (rdv.registered.high() + 1).uint64
|
||||
let namespaces =
|
||||
if d.ns.isSome():
|
||||
try:
|
||||
rdv.namespaces[d.ns.get() & rdv.salt]
|
||||
except KeyError:
|
||||
await conn.sendDiscoverResponseError(InvalidNamespace)
|
||||
return
|
||||
else:
|
||||
toSeq(max(cookie.offset.int, rdv.registered.offset) .. rdv.registered.high())
|
||||
if namespaces.len() == 0:
|
||||
await conn.sendDiscoverResponse(@[], Cookie())
|
||||
return
|
||||
var nextOffset = cookie.offset
|
||||
let n = Moment.now()
|
||||
var s = collect(newSeq()):
|
||||
for index in namespaces:
|
||||
var reg = rdv.registered[index]
|
||||
if limit == 0:
|
||||
break
|
||||
if reg.expiration < n or index.uint64 < cookie.offset:
|
||||
continue
|
||||
limit.dec()
|
||||
nextOffset = index.uint64 + 1
|
||||
reg.data.ttl = Opt.some((reg.expiration - Moment.now()).seconds.uint64)
|
||||
reg.data
|
||||
rdv.rng.shuffle(s)
|
||||
await conn.sendDiscoverResponse(s, Cookie(offset: nextOffset, ns: d.ns))
|
||||
|
||||
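# Worked sketch (illustrative, not part of the module) of the cookie paging
# implemented above: the returned offset is one past the last registration sent,
# so the next Discover carrying that cookie resumes exactly where this one stopped.
when isMainModule:
  let indices = @[3, 4, 5, 6, 7] # live registration indices for one namespace
  var
    cookieOffset = 0'u64
    limit = 2
    page: seq[int]
  for index in indices:
    if limit == 0:
      break
    if index.uint64 < cookieOffset:
      continue
    limit.dec()
    cookieOffset = index.uint64 + 1
    page.add(index)
  assert page == @[3, 4]
  assert cookieOffset == 5'u64 # the follow-up request starts at index 5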
proc advertisePeer(
|
||||
rdv: RendezVous, peer: PeerId, msg: seq[byte]
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
proc advertiseWrap() {.async: (raises: []).} =
|
||||
try:
|
||||
let conn = await rdv.switch.dial(peer, RendezVousCodec)
|
||||
defer:
|
||||
await conn.close()
|
||||
await conn.writeLp(msg)
|
||||
let
|
||||
buf = await conn.readLp(4096)
|
||||
msgRecv = Message.decode(buf).tryGet()
|
||||
if msgRecv.msgType != MessageType.RegisterResponse:
|
||||
trace "Unexpected register response", peer, msgType = msgRecv.msgType
|
||||
elif msgRecv.registerResponse.tryGet().status != ResponseStatus.Ok:
|
||||
trace "Refuse to register", peer, response = msgRecv.registerResponse
|
||||
else:
|
||||
trace "Successfully registered", peer, response = msgRecv.registerResponse
|
||||
except CatchableError as exc:
|
||||
trace "exception in the advertise", description = exc.msg
|
||||
finally:
|
||||
rdv.sema.release()
|
||||
|
||||
await rdv.sema.acquire()
|
||||
await advertiseWrap()
|
||||
|
||||
proc advertise*(
|
||||
rdv: RendezVous, ns: string, ttl: Duration, peers: seq[PeerId]
|
||||
) {.async: (raises: [CancelledError, AdvertiseError]).} =
|
||||
if ns.len < MinimumNamespaceLen or ns.len > MaximumNamespaceLen:
|
||||
raise newException(AdvertiseError, "Invalid namespace")
|
||||
|
||||
if ttl < rdv.minDuration or ttl > rdv.maxDuration:
|
||||
raise newException(AdvertiseError, "Invalid time to live: " & $ttl)
|
||||
|
||||
let sprBuff = rdv.switch.peerInfo.signedPeerRecord.encode().valueOr:
|
||||
raise newException(AdvertiseError, "Wrong Signed Peer Record")
|
||||
|
||||
let
|
||||
r = Register(ns: ns, signedPeerRecord: sprBuff, ttl: Opt.some(ttl.seconds.uint64))
|
||||
msg = encode(Message(msgType: MessageType.Register, register: Opt.some(r)))
|
||||
|
||||
rdv.save(ns, rdv.switch.peerInfo.peerId, r)
|
||||
|
||||
let futs = collect(newSeq()):
|
||||
for peer in peers:
|
||||
trace "Send Advertise", peerId = peer, ns
|
||||
rdv.advertisePeer(peer, msg.buffer).withTimeout(5.seconds)
|
||||
|
||||
await allFutures(futs)
|
||||
|
||||
method advertise*(
|
||||
rdv: RendezVous, ns: string, ttl: Duration = rdv.minDuration
|
||||
) {.base, async: (raises: [CancelledError, AdvertiseError]).} =
|
||||
await rdv.advertise(ns, ttl, rdv.peers)
|
||||
|
||||
proc requestLocally*(rdv: RendezVous, ns: string): seq[PeerRecord] =
|
||||
let
|
||||
nsSalted = ns & rdv.salt
|
||||
n = Moment.now()
|
||||
try:
|
||||
collect(newSeq()):
|
||||
for index in rdv.namespaces[nsSalted]:
|
||||
if rdv.registered[index].expiration > n:
|
||||
let res = SignedPeerRecord.decode(rdv.registered[index].data.signedPeerRecord).valueOr:
|
||||
continue
|
||||
res.data
|
||||
except KeyError:
|
||||
@[]
|
||||
|
||||
proc request*(
|
||||
rdv: RendezVous, ns: Opt[string], l: int = DiscoverLimit.int, peers: seq[PeerId]
|
||||
): Future[seq[PeerRecord]] {.async: (raises: [DiscoveryError, CancelledError]).} =
|
||||
var
|
||||
s: Table[PeerId, (PeerRecord, Register)]
|
||||
limit: uint64
|
||||
d = Discover(ns: ns)
|
||||
|
||||
if l <= 0 or l > DiscoverLimit.int:
|
||||
raise newException(AdvertiseError, "Invalid limit")
|
||||
if ns.isSome() and ns.get().len > MaximumNamespaceLen:
|
||||
raise newException(AdvertiseError, "Invalid namespace")
|
||||
|
||||
limit = l.uint64
|
||||
proc requestPeer(
|
||||
peer: PeerId
|
||||
) {.async: (raises: [CancelledError, DialFailedError, LPStreamError]).} =
|
||||
let conn = await rdv.switch.dial(peer, RendezVousCodec)
|
||||
defer:
|
||||
await conn.close()
|
||||
d.limit = Opt.some(limit)
|
||||
d.cookie =
|
||||
if ns.isSome():
|
||||
try:
|
||||
Opt.some(rdv.cookiesSaved[peer][ns.get()])
|
||||
except KeyError, CatchableError:
|
||||
Opt.none(seq[byte])
|
||||
else:
|
||||
Opt.none(seq[byte])
|
||||
await conn.writeLp(
|
||||
encode(Message(msgType: MessageType.Discover, discover: Opt.some(d))).buffer
|
||||
)
|
||||
let
|
||||
buf = await conn.readLp(MaximumMessageLen)
|
||||
msgRcv = Message.decode(buf).valueOr:
|
||||
debug "Message undecodable"
|
||||
return
|
||||
if msgRcv.msgType != MessageType.DiscoverResponse:
|
||||
debug "Unexpected discover response", msgType = msgRcv.msgType
|
||||
return
|
||||
let resp = msgRcv.discoverResponse.valueOr:
|
||||
debug "Discover response is empty"
|
||||
return
|
||||
if resp.status != ResponseStatus.Ok:
|
||||
trace "Cannot discover", ns, status = resp.status, text = resp.text
|
||||
return
|
||||
resp.cookie.withValue(cookie):
|
||||
if ns.isSome:
|
||||
let namespace = ns.get()
|
||||
if cookie.len() < 1000 and
|
||||
rdv.cookiesSaved.hasKeyOrPut(peer, {namespace: cookie}.toTable()):
|
||||
try:
|
||||
rdv.cookiesSaved[peer][namespace] = cookie
|
||||
except KeyError:
|
||||
raiseAssert "checked with hasKeyOrPut"
|
||||
for r in resp.registrations:
|
||||
if limit == 0:
|
||||
return
|
||||
let ttl = r.ttl.get(rdv.maxTTL + 1)
|
||||
if ttl > rdv.maxTTL:
|
||||
continue
|
||||
let
|
||||
spr = SignedPeerRecord.decode(r.signedPeerRecord).valueOr:
|
||||
continue
|
||||
pr = spr.data
|
||||
if s.hasKey(pr.peerId):
|
||||
let (prSaved, rSaved) =
|
||||
try:
|
||||
s[pr.peerId]
|
||||
except KeyError:
|
||||
raiseAssert "checked with hasKey"
|
||||
if (prSaved.seqNo == pr.seqNo and rSaved.ttl.get(rdv.maxTTL) < ttl) or
|
||||
prSaved.seqNo < pr.seqNo:
|
||||
s[pr.peerId] = (pr, r)
|
||||
else:
|
||||
s[pr.peerId] = (pr, r)
|
||||
limit.dec()
|
||||
if ns.isSome():
|
||||
for (_, r) in s.values():
|
||||
rdv.save(ns.get(), peer, r, false)
|
||||
|
||||
for peer in peers:
|
||||
if limit == 0:
|
||||
break
|
||||
if RendezVousCodec notin rdv.switch.peerStore[ProtoBook][peer]:
|
||||
continue
|
||||
try:
|
||||
trace "Send Request", peerId = peer, ns
|
||||
await peer.requestPeer()
|
||||
except CancelledError as e:
|
||||
raise e
|
||||
except DialFailedError as e:
|
||||
trace "failed to dial a peer", description = e.msg
|
||||
except LPStreamError as e:
|
||||
trace "failed to communicate with a peer", description = e.msg
|
||||
return toSeq(s.values()).mapIt(it[0])
|
||||
|
||||
proc request*(
|
||||
rdv: RendezVous, ns: Opt[string], l: int = DiscoverLimit.int
|
||||
): Future[seq[PeerRecord]] {.async: (raises: [DiscoveryError, CancelledError]).} =
|
||||
await rdv.request(ns, l, rdv.peers)
|
||||
|
||||
proc request*(
|
||||
rdv: RendezVous, l: int = DiscoverLimit.int
|
||||
): Future[seq[PeerRecord]] {.async: (raises: [DiscoveryError, CancelledError]).} =
|
||||
await rdv.request(Opt.none(string), l, rdv.peers)
|
||||
|
||||
proc unsubscribeLocally*(rdv: RendezVous, ns: string) =
|
||||
let nsSalted = ns & rdv.salt
|
||||
try:
|
||||
for index in rdv.namespaces[nsSalted]:
|
||||
if rdv.registered[index].peerId == rdv.switch.peerInfo.peerId:
|
||||
rdv.registered[index].expiration = rdv.expiredDT
|
||||
except KeyError:
|
||||
return
|
||||
|
||||
proc unsubscribe*(
|
||||
rdv: RendezVous, ns: string, peerIds: seq[PeerId]
|
||||
) {.async: (raises: [RendezVousError, CancelledError]).} =
|
||||
if ns.len < MinimumNamespaceLen or ns.len > MaximumNamespaceLen:
|
||||
raise newException(RendezVousError, "Invalid namespace")
|
||||
|
||||
let msg = encode(
|
||||
Message(msgType: MessageType.Unregister, unregister: Opt.some(Unregister(ns: ns)))
|
||||
)
|
||||
|
||||
proc unsubscribePeer(peerId: PeerId) {.async: (raises: []).} =
|
||||
try:
|
||||
let conn = await rdv.switch.dial(peerId, RendezVousCodec)
|
||||
defer:
|
||||
await conn.close()
|
||||
await conn.writeLp(msg.buffer)
|
||||
except CatchableError as exc:
|
||||
trace "exception while unsubscribing", description = exc.msg
|
||||
|
||||
let futs = collect(newSeq()):
|
||||
for peer in peerIds:
|
||||
unsubscribePeer(peer)
|
||||
|
||||
await allFutures(futs)
|
||||
|
||||
proc unsubscribe*(
|
||||
rdv: RendezVous, ns: string
|
||||
) {.async: (raises: [RendezVousError, CancelledError]).} =
|
||||
rdv.unsubscribeLocally(ns)
|
||||
|
||||
await rdv.unsubscribe(ns, rdv.peers)
|
||||
|
||||
proc setup*(rdv: RendezVous, switch: Switch) =
|
||||
rdv.switch = switch
|
||||
proc handlePeer(
|
||||
peerId: PeerId, event: PeerEvent
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
if event.kind == PeerEventKind.Joined:
|
||||
rdv.peers.add(peerId)
|
||||
elif event.kind == PeerEventKind.Left:
|
||||
rdv.peers.keepItIf(it != peerId)
|
||||
|
||||
rdv.switch.addPeerEventHandler(handlePeer, Joined)
|
||||
rdv.switch.addPeerEventHandler(handlePeer, Left)
|
||||
|
||||
proc new*(
|
||||
T: typedesc[RendezVous],
|
||||
rng: ref HmacDrbgContext = newRng(),
|
||||
minDuration = MinimumDuration,
|
||||
maxDuration = MaximumDuration,
|
||||
): T {.raises: [RendezVousError].} =
|
||||
if minDuration < MinimumAcceptedDuration:
|
||||
raise newException(RendezVousError, "TTL too short: 1 minute minimum")
|
||||
|
||||
if maxDuration > MaximumDuration:
|
||||
raise newException(RendezVousError, "TTL too long: 72 hours maximum")
|
||||
|
||||
if minDuration >= maxDuration:
|
||||
raise newException(RendezVousError, "Minimum TTL longer than maximum")
|
||||
|
||||
let
|
||||
minTTL = minDuration.seconds.uint64
|
||||
maxTTL = maxDuration.seconds.uint64
|
||||
|
||||
let rdv = T(
|
||||
rng: rng,
|
||||
salt: string.fromBytes(generateBytes(rng[], 8)),
|
||||
registered: initOffsettedSeq[RegisteredData](),
|
||||
expiredDT: Moment.now() - 1.days,
|
||||
#registerEvent: newAsyncEvent(),
|
||||
sema: newAsyncSemaphore(SemaphoreDefaultSize),
|
||||
minDuration: minDuration,
|
||||
maxDuration: maxDuration,
|
||||
minTTL: minTTL,
|
||||
maxTTL: maxTTL,
|
||||
)
|
||||
logScope:
|
||||
topics = "libp2p discovery rendezvous"
|
||||
proc handleStream(
|
||||
conn: Connection, proto: string
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
try:
|
||||
let
|
||||
buf = await conn.readLp(4096)
|
||||
msg = Message.decode(buf).tryGet()
|
||||
case msg.msgType
|
||||
of MessageType.Register:
|
||||
await rdv.register(conn, msg.register.tryGet())
|
||||
of MessageType.RegisterResponse:
|
||||
trace "Got an unexpected Register Response", response = msg.registerResponse
|
||||
of MessageType.Unregister:
|
||||
rdv.unregister(conn, msg.unregister.tryGet())
|
||||
of MessageType.Discover:
|
||||
await rdv.discover(conn, msg.discover.tryGet())
|
||||
of MessageType.DiscoverResponse:
|
||||
trace "Got an unexpected Discover Response", response = msg.discoverResponse
|
||||
except CancelledError as exc:
|
||||
trace "cancelled rendezvous handler"
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
trace "exception in rendezvous handler", description = exc.msg
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
rdv.handler = handleStream
|
||||
rdv.codec = RendezVousCodec
|
||||
return rdv
|
||||
|
||||
proc new*(
|
||||
T: typedesc[RendezVous],
|
||||
switch: Switch,
|
||||
rng: ref HmacDrbgContext = newRng(),
|
||||
minDuration = MinimumDuration,
|
||||
maxDuration = MaximumDuration,
|
||||
): T {.raises: [RendezVousError].} =
|
||||
let rdv = T.new(rng, minDuration, maxDuration)
|
||||
rdv.setup(switch)
|
||||
return rdv
|
||||
|
||||
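# Hedged usage sketch (not part of the original file): building the protocol
# with non-default TTL bounds and mounting it on an existing switch. Assumes
# the usual `Switch.mount` helper; error handling is left to the caller.
proc mountRendezVous(
    switch: Switch
): RendezVous {.raises: [LPError, RendezVousError].} =
  let rdv = RendezVous.new(switch, minDuration = 5.minutes, maxDuration = 12.hours)
  switch.mount(rdv)
  rdv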
proc deletesRegister*(
|
||||
rdv: RendezVous, interval = 1.minutes
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
heartbeat "Register timeout", interval:
|
||||
let n = Moment.now()
|
||||
var total = 0
|
||||
rdv.registered.flushIfIt(it.expiration < n)
|
||||
for data in rdv.namespaces.mvalues():
|
||||
data.keepItIf(it >= rdv.registered.offset)
|
||||
total += data.len
|
||||
libp2p_rendezvous_registered.set(int64(total))
|
||||
libp2p_rendezvous_namespaces.set(int64(rdv.namespaces.len))
|
||||
|
||||
method start*(
|
||||
rdv: RendezVous
|
||||
): Future[void] {.async: (raises: [CancelledError], raw: true).} =
|
||||
let fut = newFuture[void]()
|
||||
fut.complete()
|
||||
if not rdv.registerDeletionLoop.isNil:
|
||||
warn "Starting rendezvous twice"
|
||||
return fut
|
||||
rdv.registerDeletionLoop = rdv.deletesRegister()
|
||||
rdv.started = true
|
||||
fut
|
||||
|
||||
method stop*(rdv: RendezVous): Future[void] {.async: (raises: [], raw: true).} =
|
||||
let fut = newFuture[void]()
|
||||
fut.complete()
|
||||
if rdv.registerDeletionLoop.isNil:
|
||||
warn "Stopping rendezvous without starting it"
|
||||
return fut
|
||||
rdv.started = false
|
||||
rdv.registerDeletionLoop.cancelSoon()
|
||||
rdv.registerDeletionLoop = nil
|
||||
fut
|
||||
@@ -583,7 +583,8 @@ method handshake*(
|
||||
)
|
||||
conn.peerId = pid
|
||||
|
||||
var tmp = NoiseConnection.new(conn, conn.peerId, conn.observedAddr)
|
||||
var tmp =
|
||||
NoiseConnection.new(conn, conn.peerId, conn.observedAddr, conn.localAddr)
|
||||
if initiator:
|
||||
tmp.readCs = handshakeRes.cs2
|
||||
tmp.writeCs = handshakeRes.cs1
|
||||
|
||||
@@ -51,12 +51,14 @@ proc new*(
|
||||
conn: Connection,
|
||||
peerId: PeerId,
|
||||
observedAddr: Opt[MultiAddress],
|
||||
localAddr: Opt[MultiAddress],
|
||||
timeout: Duration = DefaultConnectionTimeout,
|
||||
): T =
|
||||
result = T(
|
||||
stream: conn,
|
||||
peerId: peerId,
|
||||
observedAddr: observedAddr,
|
||||
localAddr: localAddr,
|
||||
closeEvent: conn.closeEvent,
|
||||
timeout: timeout,
|
||||
dir: conn.dir,
|
||||
|
||||
@@ -62,8 +62,15 @@ proc init*(
|
||||
dir: Direction,
|
||||
timeout = DefaultChronosStreamTimeout,
|
||||
observedAddr: Opt[MultiAddress],
|
||||
localAddr: Opt[MultiAddress],
|
||||
): ChronosStream =
|
||||
result = C(client: client, timeout: timeout, dir: dir, observedAddr: observedAddr)
|
||||
result = C(
|
||||
client: client,
|
||||
timeout: timeout,
|
||||
dir: dir,
|
||||
observedAddr: observedAddr,
|
||||
localAddr: localAddr,
|
||||
)
|
||||
result.initStream()
|
||||
|
||||
template withExceptions(body: untyped) =
|
||||
@@ -151,6 +158,19 @@ method closed*(s: ChronosStream): bool =
|
||||
method atEof*(s: ChronosStream): bool =
|
||||
s.client.atEof()
|
||||
|
||||
method closeWrite*(s: ChronosStream) {.async: (raises: []).} =
|
||||
## Close the write side of the TCP connection using half-close
|
||||
if not s.client.closed():
|
||||
try:
|
||||
await s.client.shutdownWait()
|
||||
trace "Write side closed", address = $s.client.remoteAddress(), s
|
||||
except TransportError:
|
||||
# Ignore transport errors during shutdown
|
||||
discard
|
||||
except CatchableError:
|
||||
# Ignore other errors during shutdown
|
||||
discard
|
||||
|
||||
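# Hedged sketch (not part of the original file) of the half-close pattern the
# method above enables: finish writing, signal EOF to the peer, then keep
# draining the response on the read side. `conn` is any established Connection.
proc writeThenDrain(
    conn: Connection, payload: seq[byte]
) {.async: (raises: [CancelledError, LPStreamError]).} =
  await conn.write(payload)
  await conn.closeWrite() # the peer observes EOF on its read side
  var buf = newSeq[byte](4096)
  while not conn.atEof():
    let n = await conn.readOnce(addr buf[0], buf.len)
    if n <= 0:
      break
    # process buf[0 ..< n] here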
method closeImpl*(s: ChronosStream) {.async: (raises: []).} =
|
||||
trace "Shutting down chronos stream", address = $s.client.remoteAddress(), s
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ type
|
||||
timeoutHandler*: TimeoutHandler # timeout handler
|
||||
peerId*: PeerId
|
||||
observedAddr*: Opt[MultiAddress]
|
||||
localAddr*: Opt[MultiAddress]
|
||||
protocol*: string # protocol used by the connection, used as metrics tag
|
||||
transportDir*: Direction # underlying transport (usually socket) direction
|
||||
when defined(libp2p_agents_metrics):
|
||||
@@ -40,6 +41,12 @@ type
|
||||
|
||||
proc timeoutMonitor(s: Connection) {.async: (raises: []).}
|
||||
|
||||
method closeWrite*(s: Connection): Future[void] {.base, async: (raises: []).} =
|
||||
## Close the write side of the connection
|
||||
## Subclasses should implement this for their specific transport
|
||||
## Default implementation just closes the entire connection
|
||||
await s.close()
|
||||
|
||||
func shortLog*(conn: Connection): string =
|
||||
try:
|
||||
if conn == nil:
|
||||
@@ -133,13 +140,17 @@ when defined(libp2p_agents_metrics):
|
||||
var conn = s
|
||||
while conn != nil:
|
||||
conn.shortAgent = shortAgent
|
||||
conn = conn.getWrapped()
|
||||
let wrapped = conn.getWrapped()
|
||||
if wrapped == conn:
|
||||
break
|
||||
conn = wrapped
|
||||
|
||||
proc new*(
|
||||
C: type Connection,
|
||||
peerId: PeerId,
|
||||
dir: Direction,
|
||||
observedAddr: Opt[MultiAddress],
|
||||
observedAddr: Opt[MultiAddress] = Opt.none(MultiAddress),
|
||||
localAddr: Opt[MultiAddress] = Opt.none(MultiAddress),
|
||||
timeout: Duration = DefaultConnectionTimeout,
|
||||
timeoutHandler: TimeoutHandler = nil,
|
||||
): Connection =
|
||||
@@ -149,6 +160,7 @@ proc new*(
|
||||
timeout: timeout,
|
||||
timeoutHandler: timeoutHandler,
|
||||
observedAddr: observedAddr,
|
||||
localAddr: localAddr,
|
||||
)
|
||||
|
||||
result.initStream()
|
||||
|
||||
@@ -20,6 +20,7 @@ import chronos, chronicles, metrics
|
||||
import
|
||||
stream/connection,
|
||||
transports/transport,
|
||||
transports/tcptransport,
|
||||
upgrademngrs/upgrade,
|
||||
multistream,
|
||||
multiaddress,
|
||||
@@ -273,6 +274,9 @@ proc accept(s: Switch, transport: Transport) {.async: (raises: []).} =
|
||||
conn =
|
||||
try:
|
||||
await transport.accept()
|
||||
except CancelledError as exc:
|
||||
slot.release()
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
slot.release()
|
||||
raise
|
||||
@@ -351,7 +355,17 @@ proc start*(s: Switch) {.public, async: (raises: [CancelledError, LPError]).} =
|
||||
s.peerInfo.listenAddrs.keepItIf(it notin addrs)
|
||||
|
||||
if addrs.len > 0 or t.running:
|
||||
startFuts.add(t.start(addrs))
|
||||
let fut = t.start(addrs)
|
||||
startFuts.add(fut)
|
||||
if t of TcpTransport:
|
||||
await fut
|
||||
s.acceptFuts.add(s.accept(t))
|
||||
s.peerInfo.listenAddrs &= t.addrs
|
||||
|
||||
# some transports require some services to be running
|
||||
# in order to finish their startup process
|
||||
for service in s.services:
|
||||
discard await service.setup(s)
|
||||
|
||||
await allFutures(startFuts)
|
||||
|
||||
@@ -364,12 +378,11 @@ proc start*(s: Switch) {.public, async: (raises: [CancelledError, LPError]).} =
|
||||
|
||||
for t in s.transports: # for each transport
|
||||
if t.addrs.len > 0 or t.running:
|
||||
if t of TcpTransport:
|
||||
continue # already added previously
|
||||
s.acceptFuts.add(s.accept(t))
|
||||
s.peerInfo.listenAddrs &= t.addrs
|
||||
|
||||
for service in s.services:
|
||||
discard await service.setup(s)
|
||||
|
||||
await s.peerInfo.update()
|
||||
await s.ms.start()
|
||||
s.started = true
|
||||
|
||||
@@ -52,7 +52,7 @@ proc listenAddress(self: MemoryTransport, ma: MultiAddress): MultiAddress =
|
||||
|
||||
method start*(
|
||||
self: MemoryTransport, addrs: seq[MultiAddress]
|
||||
) {.async: (raises: [LPError, transport.TransportError]).} =
|
||||
) {.async: (raises: [LPError, transport.TransportError, CancelledError]).} =
|
||||
if self.running:
|
||||
return
|
||||
|
||||
|
||||
@@ -36,12 +36,20 @@ type QuicStream* = ref object of P2PConnection
|
||||
cached: seq[byte]
|
||||
|
||||
proc new(
|
||||
_: type QuicStream, stream: Stream, oaddr: Opt[MultiAddress], peerId: PeerId
|
||||
_: type QuicStream,
|
||||
stream: Stream,
|
||||
oaddr: Opt[MultiAddress],
|
||||
laddr: Opt[MultiAddress],
|
||||
peerId: PeerId,
|
||||
): QuicStream =
|
||||
let quicstream = QuicStream(stream: stream, observedAddr: oaddr, peerId: peerId)
|
||||
let quicstream =
|
||||
QuicStream(stream: stream, observedAddr: oaddr, localAddr: laddr, peerId: peerId)
|
||||
procCall P2PConnection(quicstream).initStream()
|
||||
quicstream
|
||||
|
||||
method getWrapped*(self: QuicStream): P2PConnection =
|
||||
self
|
||||
|
||||
template mapExceptions(body: untyped) =
|
||||
try:
|
||||
body
|
||||
@@ -53,15 +61,23 @@ template mapExceptions(body: untyped) =
|
||||
method readOnce*(
|
||||
stream: QuicStream, pbytes: pointer, nbytes: int
|
||||
): Future[int] {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
try:
|
||||
if stream.cached.len == 0:
|
||||
if stream.cached.len == 0:
|
||||
try:
|
||||
stream.cached = await stream.stream.read()
|
||||
result = min(nbytes, stream.cached.len)
|
||||
copyMem(pbytes, addr stream.cached[0], result)
|
||||
stream.cached = stream.cached[result ..^ 1]
|
||||
libp2p_network_bytes.inc(result.int64, labelValues = ["in"])
|
||||
except CatchableError as exc:
|
||||
raise newLPStreamEOFError()
|
||||
if stream.cached.len == 0:
|
||||
raise newLPStreamEOFError()
|
||||
except CancelledError as exc:
|
||||
raise exc
|
||||
except LPStreamEOFError as exc:
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
raise (ref LPStreamError)(msg: "error in readOnce: " & exc.msg, parent: exc)
|
||||
|
||||
let toRead = min(nbytes, stream.cached.len)
|
||||
copyMem(pbytes, addr stream.cached[0], toRead)
|
||||
stream.cached = stream.cached[toRead ..^ 1]
|
||||
libp2p_network_bytes.inc(toRead.int64, labelValues = ["in"])
|
||||
return toRead
|
||||
|
||||
{.push warning[LockLevel]: off.}
|
||||
method write*(
|
||||
@@ -72,6 +88,13 @@ method write*(
|
||||
|
||||
{.pop.}
|
||||
|
||||
method closeWrite*(stream: QuicStream) {.async: (raises: []).} =
|
||||
## Close the write side of the QUIC stream
|
||||
try:
|
||||
await stream.stream.closeWrite()
|
||||
except CatchableError:
|
||||
discard
|
||||
|
||||
method closeImpl*(stream: QuicStream) {.async: (raises: []).} =
|
||||
try:
|
||||
await stream.stream.close()
|
||||
@@ -82,8 +105,11 @@ method closeImpl*(stream: QuicStream) {.async: (raises: []).} =
|
||||
# Session
|
||||
type QuicSession* = ref object of P2PConnection
|
||||
connection: QuicConnection
|
||||
streams: seq[QuicStream]
|
||||
|
||||
method close*(session: QuicSession) {.async: (raises: []).} =
|
||||
for s in session.streams:
|
||||
await s.close()
|
||||
safeClose(session.connection)
|
||||
await procCall P2PConnection(session).close()
|
||||
|
||||
@@ -98,19 +124,33 @@ proc getStream*(
|
||||
of Direction.Out:
|
||||
stream = await session.connection.openStream()
|
||||
await stream.write(@[]) # QUIC streams do not exist until data is sent
|
||||
return QuicStream.new(stream, session.observedAddr, session.peerId)
|
||||
|
||||
let qs =
|
||||
QuicStream.new(stream, session.observedAddr, session.localAddr, session.peerId)
|
||||
when defined(libp2p_agents_metrics):
|
||||
qs.shortAgent = session.shortAgent
|
||||
|
||||
session.streams.add(qs)
|
||||
return qs
|
||||
except CatchableError as exc:
|
||||
# TODO: incomingStream is using {.async.} with no raises
|
||||
raise (ref QuicTransportError)(msg: "error in getStream: " & exc.msg, parent: exc)
|
||||
|
||||
method getWrapped*(self: QuicSession): P2PConnection =
|
||||
nil
|
||||
self
|
||||
|
||||
# Muxer
|
||||
type QuicMuxer = ref object of Muxer
|
||||
quicSession: QuicSession
|
||||
handleFut: Future[void]
|
||||
|
||||
when defined(libp2p_agents_metrics):
|
||||
method setShortAgent*(m: QuicMuxer, shortAgent: string) =
|
||||
m.quicSession.shortAgent = shortAgent
|
||||
for s in m.quicSession.streams:
|
||||
s.shortAgent = shortAgent
|
||||
m.connection.shortAgent = shortAgent
|
||||
|
||||
method newStream*(
|
||||
m: QuicMuxer, name: string = "", lazy: bool = false
|
||||
): Future[P2PConnection] {.
|
||||
@@ -129,7 +169,7 @@ proc handleStream(m: QuicMuxer, chann: QuicStream) {.async: (raises: []).} =
|
||||
trace "finished handling stream"
|
||||
doAssert(chann.closed, "connection not closed by handler!")
|
||||
except CatchableError as exc:
|
||||
trace "Exception in mplex stream handler", msg = exc.msg
|
||||
trace "Exception in quic stream handler", msg = exc.msg
|
||||
await chann.close()
|
||||
|
||||
method handle*(m: QuicMuxer): Future[void] {.async: (raises: []).} =
|
||||
@@ -138,7 +178,7 @@ method handle*(m: QuicMuxer): Future[void] {.async: (raises: []).} =
|
||||
let incomingStream = await m.quicSession.getStream(Direction.In)
|
||||
asyncSpawn m.handleStream(incomingStream)
|
||||
except CatchableError as exc:
|
||||
trace "Exception in mplex handler", msg = exc.msg
|
||||
trace "Exception in quic handler", msg = exc.msg
|
||||
|
||||
method close*(m: QuicMuxer) {.async: (raises: []).} =
|
||||
try:
|
||||
@@ -155,7 +195,7 @@ type CertGenerator =
|
||||
|
||||
type QuicTransport* = ref object of Transport
|
||||
listener: Listener
|
||||
client: QuicClient
|
||||
client: Opt[QuicClient]
|
||||
privateKey: PrivateKey
|
||||
connections: seq[P2PConnection]
|
||||
rng: ref HmacDrbgContext
|
||||
@@ -208,27 +248,33 @@ method handles*(transport: QuicTransport, address: MultiAddress): bool {.raises:
|
||||
return false
|
||||
QUIC_V1.match(address)
|
||||
|
||||
method start*(
|
||||
self: QuicTransport, addrs: seq[MultiAddress]
|
||||
) {.async: (raises: [LPError, transport.TransportError]).} =
|
||||
doAssert self.listener.isNil, "start() already called"
|
||||
#TODO handle multiple addr
|
||||
|
||||
proc makeConfig(self: QuicTransport): TLSConfig =
|
||||
let pubkey = self.privateKey.getPublicKey().valueOr:
|
||||
doAssert false, "could not obtain public key"
|
||||
return
|
||||
|
||||
try:
|
||||
if self.rng.isNil:
|
||||
self.rng = newRng()
|
||||
let cert = self.certGenerator(KeyPair(seckey: self.privateKey, pubkey: pubkey))
|
||||
let tlsConfig = TLSConfig.init(
|
||||
cert.certificate, cert.privateKey, @[alpn], Opt.some(makeCertificateVerifier())
|
||||
)
|
||||
return tlsConfig
|
||||
|
||||
let cert = self.certGenerator(KeyPair(seckey: self.privateKey, pubkey: pubkey))
|
||||
let tlsConfig = TLSConfig.init(
|
||||
cert.certificate, cert.privateKey, @[alpn], Opt.some(makeCertificateVerifier())
|
||||
)
|
||||
self.client = QuicClient.init(tlsConfig, rng = self.rng)
|
||||
self.listener =
|
||||
QuicServer.init(tlsConfig, rng = self.rng).listen(initTAddress(addrs[0]).tryGet)
|
||||
proc getRng(self: QuicTransport): ref HmacDrbgContext =
|
||||
if self.rng.isNil:
|
||||
self.rng = newRng()
|
||||
|
||||
return self.rng
|
||||
|
||||
method start*(
|
||||
self: QuicTransport, addrs: seq[MultiAddress]
|
||||
) {.async: (raises: [LPError, transport.TransportError, CancelledError]).} =
|
||||
doAssert self.listener.isNil, "start() already called"
|
||||
# TODO(#1663): handle multiple addr
|
||||
|
||||
try:
|
||||
self.listener = QuicServer.init(self.makeConfig(), rng = self.getRng()).listen(
|
||||
initTAddress(addrs[0]).tryGet
|
||||
)
|
||||
await procCall Transport(self).start(addrs)
|
||||
self.addrs[0] =
|
||||
MultiAddress.init(self.listener.localAddress(), IPPROTO_UDP).tryGet() &
|
||||
@@ -249,27 +295,36 @@ method start*(
|
||||
self.running = true
|
||||
|
||||
method stop*(transport: QuicTransport) {.async: (raises: []).} =
|
||||
if transport.running:
|
||||
for c in transport.connections:
|
||||
await c.close()
|
||||
await procCall Transport(transport).stop()
|
||||
let conns = transport.connections[0 .. ^1]
|
||||
for c in conns:
|
||||
await c.close()
|
||||
|
||||
if not transport.listener.isNil:
|
||||
try:
|
||||
await transport.listener.stop()
|
||||
except CatchableError as exc:
|
||||
trace "Error shutting down Quic transport", description = exc.msg
|
||||
transport.listener.destroy()
|
||||
transport.running = false
|
||||
transport.listener = nil
|
||||
|
||||
transport.client = Opt.none(QuicClient)
|
||||
await procCall Transport(transport).stop()
|
||||
|
||||
proc wrapConnection(
|
||||
transport: QuicTransport, connection: QuicConnection
|
||||
): QuicSession {.raises: [TransportOsError, MaError].} =
|
||||
let
|
||||
remoteAddr = connection.remoteAddress()
|
||||
observedAddr =
|
||||
MultiAddress.init(remoteAddr, IPPROTO_UDP).get() &
|
||||
MultiAddress.init(connection.remoteAddress(), IPPROTO_UDP).get() &
|
||||
MultiAddress.init("/quic-v1").get()
|
||||
session = QuicSession(connection: connection, observedAddr: Opt.some(observedAddr))
|
||||
localAddr =
|
||||
MultiAddress.init(connection.localAddress(), IPPROTO_UDP).get() &
|
||||
MultiAddress.init("/quic-v1").get()
|
||||
session = QuicSession(
|
||||
connection: connection,
|
||||
observedAddr: Opt.some(observedAddr),
|
||||
localAddr: Opt.some(localAddr),
|
||||
)
|
||||
|
||||
session.initStream()
|
||||
|
||||
@@ -289,12 +344,12 @@ method accept*(
|
||||
): Future[connection.Connection] {.
|
||||
async: (raises: [transport.TransportError, CancelledError])
|
||||
.} =
|
||||
doAssert not self.listener.isNil, "call start() before calling accept()"
|
||||
|
||||
if not self.running:
|
||||
# stop accept only when transport is stopped (not when error occurs)
|
||||
raise newException(QuicTransportAcceptStopped, "Quic transport stopped")
|
||||
|
||||
doAssert not self.listener.isNil, "call start() before calling accept()"
|
||||
|
||||
try:
|
||||
let connection = await self.listener.accept()
|
||||
return self.wrapConnection(connection)
|
||||
@@ -318,7 +373,11 @@ method dial*(
|
||||
async: (raises: [transport.TransportError, CancelledError])
|
||||
.} =
|
||||
try:
|
||||
let quicConnection = await self.client.dial(initTAddress(address).tryGet)
|
||||
if not self.client.isSome:
|
||||
self.client = Opt.some(QuicClient.init(self.makeConfig(), rng = self.getRng()))
|
||||
|
||||
let client = self.client.get()
|
||||
let quicConnection = await client.dial(initTAddress(address).tryGet)
|
||||
return self.wrapConnection(quicConnection)
|
||||
except CancelledError as e:
|
||||
raise e
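The dial hunk above creates the QUIC client lazily on first use. A self-contained sketch of that Opt-based lazy-initialization pattern follows; the Client and Holder names are illustrative only, not from the repository:

# hedged sketch of the lazy Opt pattern, not repository code
import results

type
  Client = ref object
    id: int
  Holder = ref object
    client: Opt[Client]

proc getOrCreate(h: Holder): Client =
  # build the client only when first needed, then reuse it
  if not h.client.isSome:
    h.client = Opt.some(Client(id: 1))
  h.client.get()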

@@ -47,6 +47,7 @@ proc connHandler*(
self: TcpTransport,
client: StreamTransport,
observedAddr: Opt[MultiAddress],
localAddr: Opt[MultiAddress],
dir: Direction,
): Connection =
trace "Handling tcp connection",
@@ -59,6 +60,7 @@ proc connHandler*(
client = client,
dir = dir,
observedAddr = observedAddr,
localAddr = localAddr,
timeout = self.connectionsTimeout,
)
)
@@ -107,7 +109,7 @@ proc new*(

method start*(
self: TcpTransport, addrs: seq[MultiAddress]
): Future[void] {.async: (raises: [LPError, transport.TransportError]).} =
): Future[void] {.async: (raises: [LPError, transport.TransportError, CancelledError]).} =
## Start transport listening to the given addresses - for dial-only transports,
## start with an empty list

@@ -267,18 +269,22 @@ method accept*(
safeCloseWait(transp)
raise newTransportClosedError()

let remote =
let (localAddr, observedAddr) =
try:
transp.remoteAddress
(
MultiAddress.init(transp.localAddress).expect(
"Can initialize from local address"
),
MultiAddress.init(transp.remoteAddress).expect(
"Can initialize from remote address"
),
)
except TransportOsError as exc:
# The connection had errors / was closed before `await` returned control
safeCloseWait(transp)
debug "Cannot read remote address", description = exc.msg
debug "Cannot read address", description = exc.msg
return nil

let observedAddr =
MultiAddress.init(remote).expect("Can initialize from remote address")
self.connHandler(transp, Opt.some(observedAddr), Direction.In)
self.connHandler(transp, Opt.some(observedAddr), Opt.some(localAddr), Direction.In)

method dial*(
self: TcpTransport,
@@ -320,14 +326,17 @@ method dial*(
safeCloseWait(transp)
raise newTransportClosedError()

let observedAddr =
let (observedAddr, localAddr) =
try:
MultiAddress.init(transp.remoteAddress).expect("remote address is valid")
(
MultiAddress.init(transp.remoteAddress).expect("remote address is valid"),
MultiAddress.init(transp.localAddress).expect("local address is valid"),
)
except TransportOsError as exc:
safeCloseWait(transp)
raise (ref TcpTransportError)(msg: "MultiAddress.init error in dial: " & exc.msg)

self.connHandler(transp, Opt.some(observedAddr), Direction.Out)
self.connHandler(transp, Opt.some(observedAddr), Opt.some(localAddr), Direction.Out)

method handles*(t: TcpTransport, address: MultiAddress): bool {.raises: [].} =
if procCall Transport(t).handles(address):

@@ -801,6 +801,44 @@ cleanup:
return ret_code;
}

cert_error_t cert_new_key_t(cert_buffer *seckey, cert_key_t *out) {
BIO *bio = NULL;
cert_error_t ret_code = CERT_SUCCESS;

if (out == NULL) {
return CERT_ERROR_NULL_PARAM;
}

struct cert_key_s *key = calloc(1, sizeof(struct cert_key_s));
if (key == NULL) {
return CERT_ERROR_MEMORY;
}

bio = BIO_new_mem_buf(seckey->data, seckey->len);
if (!bio) {
ret_code = CERT_ERROR_BIO_GEN;
goto cleanup;
}

EVP_PKEY *pkey = d2i_PrivateKey_bio(bio, NULL);

key->pkey = pkey;
*out = key;

cleanup:
if (bio)
BIO_free(bio);

if (ret_code != CERT_SUCCESS && *out) {
if (pkey)
EVP_PKEY_free(pkey);
free(key);
*out = NULL;
}

return ret_code;
}

cert_error_t cert_serialize_privk(cert_key_t key, cert_buffer **out,
cert_format_t format) {
BIO *bio = NULL;

@@ -106,6 +106,16 @@ cert_error_t cert_init_drbg(const char *seed, size_t seed_len,
*/
cert_error_t cert_generate_key(cert_context_t ctx, cert_key_t *out);

/**
* Copy a DER-formatted seckey to a cert_key_t
*
* @param seckey Private Key bytes in DER format
* @param out Pointer to store the key as cert_key_t
*
* @return CERT_SUCCESS on successful execution, an error code otherwise
*/
cert_error_t cert_new_key_t(cert_buffer *seckey, cert_key_t *out);

/**
* Serialize a key's private key to a format
*

@@ -42,6 +42,10 @@ proc cert_generate_key*(
ctx: cert_context_t, out_arg: ptr cert_key_t
): cert_error_t {.cdecl, importc: "cert_generate_key".}

proc cert_new_key_t*(
seckey: ptr cert_buffer, certKey: ptr cert_key_t
): cert_error_t {.cdecl, importc: "cert_new_key_t".}
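A hedged sketch of driving the new binding from Nim; it assumes `secKeyBuf` is a ptr cert_buffer that already holds a DER-encoded private key and that a zero cert_error_t means success (CERT_SUCCESS in the C code above):

# hedged usage sketch, not repository code
var key: cert_key_t
let rc = cert_new_key_t(secKeyBuf, addr key)
if rc != cert_error_t(0):
  raise newException(ValueError, "cert_new_key_t failed with code " & $rc.int)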

proc cert_serialize_privk*(
key: cert_key_t, out_arg: ptr ptr cert_buffer, format: cert_format_t
): cert_error_t {.cdecl, importc: "cert_serialize_privk".}

@@ -18,7 +18,6 @@ import
transport,
tcptransport,
../switch,
../autotls/service,
../builders,
../stream/[lpstream, connection, chronosstream],
../multiaddress,
@@ -238,7 +237,9 @@ method dial*(
try:
transp = await connectToTorServer(self.transportAddress)
await dialPeer(transp, address)
return self.tcpTransport.connHandler(transp, Opt.none(MultiAddress), Direction.Out)
return self.tcpTransport.connHandler(
transp, Opt.none(MultiAddress), Opt.none(MultiAddress), Direction.Out
)
except CancelledError as e:
safeCloseWait(transp)
raise e
@@ -250,7 +251,7 @@ method dial*(

method start*(
self: TorTransport, addrs: seq[MultiAddress]
) {.async: (raises: [LPError, transport.TransportError]).} =
) {.async: (raises: [LPError, transport.TransportError, CancelledError]).} =
## listen on the transport
##

@@ -304,8 +305,8 @@ proc new*(
flags: set[ServerFlags] = {},
): TorSwitch {.raises: [LPError], public.} =
var builder = SwitchBuilder.new().withRng(rng).withTransport(
proc(upgr: Upgrade, privateKey: PrivateKey, autotls: AutotlsService): Transport =
TorTransport.new(torServer, flags, upgr)
proc(config: TransportConfig): Transport =
TorTransport.new(torServer, flags, config.upgr)
)
if addresses.len != 0:
builder = builder.withAddresses(addresses)

@@ -18,9 +18,9 @@ import
../multicodec,
../muxers/muxer,
../upgrademngrs/upgrade,
../protocols/connectivity/autonat/core
../protocols/connectivity/autonat/types

export core.NetworkReachability
export types.NetworkReachability

logScope:
topics = "libp2p transport"
@@ -42,7 +42,7 @@ proc newTransportClosedError*(parent: ref Exception = nil): ref TransportError =

method start*(
self: Transport, addrs: seq[MultiAddress]
) {.base, async: (raises: [LPError, TransportError]).} =
) {.base, async: (raises: [LPError, TransportError, CancelledError]).} =
## start the transport
##
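Because the base signature now includes CancelledError, transport overrides must declare it as well. A minimal sketch of a conforming override (MyTransport is hypothetical, not from the repository):

method start*(
    self: MyTransport, addrs: seq[MultiAddress]
) {.async: (raises: [LPError, TransportError, CancelledError]).} =
  # delegate to the base implementation, then mark the transport as running
  await procCall Transport(self).start(addrs)
  self.running = true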


@@ -16,6 +16,7 @@ import results
import chronos, chronicles
import
transport,
../autotls/service,
../errors,
../wire,
../multicodec,
@@ -32,7 +33,10 @@ logScope:

export transport, websock, results

const DefaultHeadersTimeout = 3.seconds
const
DefaultHeadersTimeout = 3.seconds
DefaultAutotlsWaitTimeout = 3.seconds
DefaultAutotlsRetries = 3

type
WsStream = ref object of Connection
@@ -51,10 +55,16 @@ proc new*(
session: WSSession,
dir: Direction,
observedAddr: Opt[MultiAddress],
localAddr: Opt[MultiAddress],
timeout = 10.minutes,
): T =
let stream =
T(session: session, timeout: timeout, dir: dir, observedAddr: observedAddr)
let stream = T(
session: session,
timeout: timeout,
dir: dir,
observedAddr: observedAddr,
localAddr: localAddr,
)

stream.initStream()
return stream
@@ -105,11 +115,11 @@ type WsTransport* = ref object of Transport
httpservers: seq[HttpServer]
wsserver: WSServer
connections: array[Direction, seq[WsStream]]

acceptFuts: seq[Future[HttpRequest]]

tlsPrivateKey: TLSPrivateKey
tlsCertificate: TLSCertificate
tlsPrivateKey*: TLSPrivateKey
tlsCertificate*: TLSCertificate
autotls: AutotlsService
tlsFlags: set[TLSFlags]
flags: set[ServerFlags]
handshakeTimeout: Duration
@@ -121,7 +131,7 @@ proc secure*(self: WsTransport): bool =

method start*(
self: WsTransport, addrs: seq[MultiAddress]
) {.async: (raises: [LPError, transport.TransportError]).} =
) {.async: (raises: [LPError, transport.TransportError, CancelledError]).} =
## listen on the transport
##

@@ -129,8 +139,25 @@ method start*(
warn "WS transport already running"
return

await procCall Transport(self).start(addrs)
when defined(libp2p_autotls_support):
if not self.secure and not self.autotls.isNil():
if not await self.autotls.running.wait().withTimeout(DefaultAutotlsWaitTimeout):
error "Unable to upgrade, autotls not running"
await self.stop()
return

trace "Waiting for autotls certificate"
try:
let autotlsCert = await self.autotls.getCertWhenReady()
self.tlsCertificate = autotlsCert.cert
self.tlsPrivateKey = autotlsCert.privkey
except AutoTLSError as exc:
raise newException(LPError, exc.msg, exc)
except TLSStreamProtocolError as exc:
raise newException(LPError, exc.msg, exc)

trace "Starting WS transport"
await procCall Transport(self).start(addrs)

self.wsserver = WSServer.new(factories = self.factories, rng = self.rng)

@@ -140,7 +167,7 @@ method start*(
if self.secure:
true
else:
warn "Trying to listen on a WSS address without setting certificate!"
warn "Trying to listen on a WSS address without setting certificate or autotls!"
false
else:
false
@@ -181,8 +208,6 @@ method start*(

trace "Listening on", addresses = self.addrs

self.running = true

method stop*(self: WsTransport) {.async: (raises: []).} =
## stop the transport
##
@@ -193,18 +218,16 @@ method stop*(self: WsTransport) {.async: (raises: []).} =
trace "Stopping WS transport"
await procCall Transport(self).stop() # call base

checkFutures(
await allFinished(
self.connections[Direction.In].mapIt(it.close()) &
self.connections[Direction.Out].mapIt(it.close())
)
discard await allFinished(
self.connections[Direction.In].mapIt(it.close()) &
self.connections[Direction.Out].mapIt(it.close())
)

var toWait: seq[Future[void]]
for fut in self.acceptFuts:
if not fut.finished:
toWait.add(fut.cancelAndWait())
elif fut.done:
elif fut.completed:
toWait.add(fut.read().stream.closeWait())

for server in self.httpservers:
@@ -222,9 +245,8 @@ proc connHandler(
self: WsTransport, stream: WSSession, secure: bool, dir: Direction
): Future[Connection] {.async: (raises: [CatchableError]).} =
## Returning CatchableError is fine because we later handle different exceptions.
##

let observedAddr =
let (observedAddr, localAddr) =
try:
let
codec =
@@ -233,15 +255,19 @@ proc connHandler(
else:
MultiAddress.init("/ws")
remoteAddr = stream.stream.reader.tsource.remoteAddress
localAddr = stream.stream.reader.tsource.localAddress

MultiAddress.init(remoteAddr).tryGet() & codec.tryGet()
(
MultiAddress.init(remoteAddr).tryGet() & codec.tryGet(),
MultiAddress.init(localAddr).tryGet() & codec.tryGet(),
)
except CatchableError as exc:
trace "Failed to create observedAddr", description = exc.msg
trace "Failed to create observedAddr or listenAddr", description = exc.msg
if not (isNil(stream) and stream.stream.reader.closed):
safeClose(stream)
raise exc

let conn = WsStream.new(stream, dir, Opt.some(observedAddr))
let conn = WsStream.new(stream, dir, Opt.some(observedAddr), Opt.some(localAddr))

self.connections[dir].add(conn)
proc onClose() {.async: (raises: []).} =
@@ -255,8 +281,14 @@ proc connHandler(
method accept*(
self: WsTransport
): Future[Connection] {.async: (raises: [transport.TransportError, CancelledError]).} =
## accept a new WS connection
##
trace "WsTransport accept"

# wstransport can only start accepting connections after autotls is done
# if autotls is not present, self.running will be true right after start is called
var retries = 0
while not self.running and retries < DefaultAutotlsRetries:
retries += 1
await sleepAsync(DefaultAutotlsWaitTimeout)

if not self.running:
raise newTransportClosedError()
@@ -282,9 +314,8 @@ method accept*(
let req = await finished

try:
let
wstransp = await self.wsserver.handleRequest(req).wait(self.handshakeTimeout)
isSecure = self.httpservers[index].secure
let wstransp = await self.wsserver.handleRequest(req).wait(self.handshakeTimeout)
let isSecure = self.httpservers[index].secure

return await self.connHandler(wstransp, isSecure, Direction.In)
except CatchableError as exc:
@@ -329,12 +360,11 @@ method dial*(

try:
let secure = WSS.match(address)
let initAddress = address.initTAddress().tryGet()
debug "creating websocket",
address = initAddress, secure = secure, hostName = hostname
transp = await WebSocket.connect(
address.initTAddress().tryGet(),
"",
secure = secure,
hostName = hostname,
flags = self.tlsFlags,
initAddress, "", secure = secure, hostName = hostname, flags = self.tlsFlags
)
return await self.connHandler(transp, secure, Direction.Out)
except CancelledError as e:
@@ -356,6 +386,7 @@ proc new*(
upgrade: Upgrade,
tlsPrivateKey: TLSPrivateKey,
tlsCertificate: TLSCertificate,
autotls: AutotlsService,
tlsFlags: set[TLSFlags] = {},
flags: set[ServerFlags] = {},
factories: openArray[ExtFactory] = [],
@@ -368,6 +399,7 @@ proc new*(
upgrader: upgrade,
tlsPrivateKey: tlsPrivateKey,
tlsCertificate: tlsCertificate,
autotls: autotls,
tlsFlags: tlsFlags,
flags: flags,
factories: @factories,
@@ -389,6 +421,7 @@ proc new*(
upgrade = upgrade,
tlsPrivateKey = nil,
tlsCertificate = nil,
autotls = nil,
flags = flags,
factories = @factories,
rng = rng,

74
libp2p/utils/ipaddr.nim
Normal file
@@ -0,0 +1,74 @@
import net, strutils

import ../switch, ../multiaddress, ../multicodec

proc isIPv4*(ip: IpAddress): bool =
ip.family == IpAddressFamily.IPv4

proc isIPv6*(ip: IpAddress): bool =
ip.family == IpAddressFamily.IPv6

proc isPrivate*(ip: string): bool {.raises: [ValueError].} =
ip.startsWith("10.") or
(ip.startsWith("172.") and parseInt(ip.split(".")[1]) in 16 .. 31) or
ip.startsWith("192.168.") or ip.startsWith("127.") or ip.startsWith("169.254.")

proc isPrivate*(ip: IpAddress): bool {.raises: [ValueError].} =
isPrivate($ip)

proc isPublic*(ip: string): bool {.raises: [ValueError].} =
not isPrivate(ip)

proc isPublic*(ip: IpAddress): bool {.raises: [ValueError].} =
isPublic($ip)

proc getPublicIPAddress*(): IpAddress {.raises: [OSError, ValueError].} =
let ip =
try:
getPrimaryIPAddr()
except OSError as exc:
raise exc
except ValueError as exc:
raise exc
except Exception as exc:
raise newException(OSError, "Could not get primary IP address")
if not ip.isIPv4():
raise newException(ValueError, "Host does not have an IPv4 address")
if not ip.isPublic():
raise newException(ValueError, "Host does not have a public IPv4 address")
ip

proc ipAddrMatches*(
lookup: MultiAddress, addrs: seq[MultiAddress], ip4: bool = true
): bool =
## Checks whether ``lookup``'s IP is in any of ``addrs``

let ipType =
if ip4:
multiCodec("ip4")
else:
multiCodec("ip6")

let lookup = lookup.getPart(ipType).valueOr:
return false

for ma in addrs:
ma[0].withValue(ipAddr):
if ipAddr == lookup:
return true
false

proc ipSupport*(addrs: seq[MultiAddress]): (bool, bool) =
## Returns the IPv4 and IPv6 support status of a list of MultiAddresses

var ipv4 = false
var ipv6 = false

for ma in addrs:
ma[0].withValue(addrIp):
if IP4.match(addrIp):
ipv4 = true
elif IP6.match(addrIp):
ipv6 = true

(ipv4, ipv6)
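A short, hedged usage sketch of the helpers in this new file (the import path and the concrete addresses are illustrative assumptions):

# hedged usage sketch, not repository code
import libp2p/utils/ipaddr, libp2p/multiaddress

assert isPrivate("192.168.1.10")   # RFC1918, loopback and link-local prefixes count as private
assert isPublic("8.8.8.8")

let addrs = @[
  MultiAddress.init("/ip4/127.0.0.1/tcp/4001").tryGet(),
  MultiAddress.init("/ip6/::1/tcp/4001").tryGet(),
]
let (hasV4, hasV6) = ipSupport(addrs)
assert hasV4 and hasV6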
3
performance/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
nimbledeps/*
output/
bin/
25
performance/Dockerfile
Normal file
@@ -0,0 +1,25 @@
# Create the build image
FROM nimlang/nim:2.2.4-alpine-regular AS build

WORKDIR /node

COPY libp2p.nimble config.nims ./
RUN git config --global http.sslVerify false
RUN nimble install -y

COPY . .
RUN nimble c -d:chronicles_colors=None --threads:on -d:metrics -d:libp2p_network_protocols_metrics -d:release performance/main.nim


FROM nimlang/nim:2.2.4-alpine-slim

WORKDIR /node

COPY --from=build /node/performance/main /node/main

RUN chmod +x main \
&& apk add --no-cache curl iproute2

VOLUME ["/output"]

ENTRYPOINT ["./main"]
1
performance/main.nim
Normal file
@@ -0,0 +1 @@
import ./scenarios
55
performance/runner.sh
Executable file
@@ -0,0 +1,55 @@
#!/bin/bash

set -euo pipefail

# Create Docker network
network="performance-test-network"
if ! docker network inspect "$network" > /dev/null 2>&1; then
docker network create --attachable --driver bridge "$network" > /dev/null
fi

# Clean up output
output_dir="$(pwd)/performance/output"
mkdir -p "$output_dir"
rm -rf "$output_dir"
mkdir -p "$output_dir/sync"

# Run Test Nodes
container_names=()
PEERS=10
for ((i = 0; i < $PEERS; i++)); do
hostname_prefix="node-"
hostname="$hostname_prefix$i"

docker run -d \
--cap-add=NET_ADMIN \
--name "$hostname" \
-e NODE_ID="$i" \
-e HOSTNAME_PREFIX="$hostname_prefix" \
-v "$output_dir:/output" \
-v /var/run/docker.sock:/var/run/docker.sock \
--hostname="$hostname" \
--network="$network" \
test-node > /dev/null

container_names+=("$hostname")
done

# Show logs in real time for all containers
for container_name in "${container_names[@]}"; do
docker logs -f "$container_name" &
done

# Wait for all containers to finish
for container_name in "${container_names[@]}"; do
docker wait "$container_name" > /dev/null
done

# Clean up all containers
for container_name in "${container_names[@]}"; do
docker rm -f "$container_name" > /dev/null
done

# Remove the custom Docker network
docker network rm "$network" > /dev/null
exit 0
196
performance/scenarios.nim
Normal file
@@ -0,0 +1,196 @@
# Nim-LibP2P
# Copyright (c) 2025 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.

{.used.}

import metrics
import metrics/chronos_httpserver
import os
import osproc
import strformat
import strutils
import ../libp2p
import ../libp2p/protocols/ping
import ../tests/helpers
import ./utils
from nativesockets import getHostname

proc baseTest*(scenarioName = "Base test") {.async.} =
# --- Scenario ---
let scenario = scenarioName
const
nodeCount = 10
publisherCount = 5
peerLimit = 5
msgCount = 100
msgInterval = 100 # ms
msgSize = 200 # bytes
warmupCount = 10

# --- Node Setup ---
let
hostnamePrefix = getEnv("HOSTNAME_PREFIX", "unknown")
nodeId = parseInt(getEnv("NODE_ID", "0"))
hostname = getHostname()
rng = libp2p.newRng()

if nodeId == 0:
clearSyncFiles()

# --- Collect docker stats for one publishing and one non-publishing node ---
var dockerStatsProc: Process = nil
if nodeId == 0 or nodeId == publisherCount + 1:
let dockerStatsLogPath = getDockerStatsLogPath(scenario, nodeId)
dockerStatsProc = startDockerStatsProcess(nodeId, dockerStatsLogPath)
defer:
dockerStatsProc.stopDockerStatsProcess()

let (switch, gossipSub, pingProtocol) = setupNode(nodeId, rng)
gossipSub.setGossipSubParams()

var (messageHandler, receivedMessages) = createMessageHandler(nodeId)
gossipSub.subscribe(topic, messageHandler)

gossipSub.addValidator([topic], defaultMessageValidator)

switch.mount(gossipSub)
switch.mount(pingProtocol)

await switch.start()
defer:
await switch.stop()

info "Node started, synchronizing",
scenario,
nodeId,
address = switch.peerInfo.addrs,
peerId = switch.peerInfo.peerId,
isPublisher = nodeId <= publisherCount,
hostname = hostname

await syncNodes("started", nodeId, nodeCount)

# --- Peer Discovery & Connection ---
var peersAddresses = resolvePeersAddresses(nodeCount, hostnamePrefix, nodeId)
rng.shuffle(peersAddresses)

await connectPeers(switch, peersAddresses, peerLimit, nodeId)

info "Mesh populated, synchronizing",
nodeId, meshSize = gossipSub.mesh.getOrDefault(topic).len

await syncNodes("mesh", nodeId, nodeCount)

# --- Message Publishing ---
let sentMessages = await publishMessagesWithWarmup(
gossipSub, warmupCount, msgCount, msgInterval, msgSize, publisherCount, nodeId
)

info "Waiting for message delivery, synchronizing"

await syncNodes("published", nodeId, nodeCount)

# --- Performance summary ---
let stats = getStats(scenario, receivedMessages[], sentMessages)
info "Performance summary", nodeId, stats = $stats

let outputPath = "/output/" & hostname & ".json"
writeResultsToJson(outputPath, scenario, stats)

await syncNodes("finished", nodeId, nodeCount)

suite "Network Performance Tests":
teardown:
checkTrackers()

asyncTest "Base Test":
await baseTest()

asyncTest "Latency Test":
const
latency = 100
jitter = 20

discard execShellCommand(
fmt"{enableTcCommand} netem delay {latency}ms {jitter}ms distribution normal"
)
await baseTest(fmt"Latency {latency}ms {jitter}ms")
discard execShellCommand(disableTcCommand)

asyncTest "Packet Loss Test":
const packetLoss = 5

discard execShellCommand(fmt"{enableTcCommand} netem loss {packetLoss}%")
await baseTest(fmt"Packet Loss {packetLoss}%")
discard execShellCommand(disableTcCommand)

asyncTest "Low Bandwidth Test":
const
rate = "256kbit"
burst = "8kbit"
limit = "5000"

discard
execShellCommand(fmt"{enableTcCommand} tbf rate {rate} burst {burst} limit {limit}")
await baseTest(fmt"Low Bandwidth rate {rate} burst {burst} limit {limit}")
discard execShellCommand(disableTcCommand)

asyncTest "Packet Reorder Test":
const
reorderPercent = 15
reorderCorr = 40
delay = 2

discard execShellCommand(
fmt"{enableTcCommand} netem delay {delay}ms reorder {reorderPercent}% {reorderCorr}%"
)
await baseTest(
fmt"Packet Reorder {reorderPercent}% {reorderCorr}% with {delay}ms delay"
)
discard execShellCommand(disableTcCommand)

asyncTest "Burst Loss Test":
const
lossPercent = 8
lossCorr = 30

discard execShellCommand(fmt"{enableTcCommand} netem loss {lossPercent}% {lossCorr}%")
await baseTest(fmt"Burst Loss {lossPercent}% {lossCorr}%")
discard execShellCommand(disableTcCommand)

asyncTest "Duplication Test":
const duplicatePercent = 2

discard execShellCommand(fmt"{enableTcCommand} netem duplicate {duplicatePercent}%")
await baseTest(fmt"Duplication {duplicatePercent}%")
discard execShellCommand(disableTcCommand)

asyncTest "Corruption Test":
const corruptPercent = 0.5

discard execShellCommand(fmt"{enableTcCommand} netem corrupt {corruptPercent}%")
await baseTest(fmt"Corruption {corruptPercent}%")
discard execShellCommand(disableTcCommand)

asyncTest "Queue Limit Test":
const queueLimit = 5

discard execShellCommand(fmt"{enableTcCommand} netem limit {queueLimit}")
await baseTest(fmt"Queue Limit {queueLimit}")
discard execShellCommand(disableTcCommand)

asyncTest "Combined Network Conditions Test":
discard execShellCommand(
"tc qdisc add dev eth0 root handle 1:0 tbf rate 2mbit burst 32kbit limit 25000"
)
discard execShellCommand(
"tc qdisc add dev eth0 parent 1:1 handle 10: netem delay 100ms 20ms distribution normal loss 5% 20% reorder 10% 30% duplicate 0.5% corrupt 0.05% limit 20"
)
await baseTest("Combined Network Conditions")
discard execShellCommand(disableTcCommand)
87
performance/scripts/add_plots_to_summary.nim
Normal file
@@ -0,0 +1,87 @@
import os
import algorithm
import sequtils
import strformat
import strutils
import tables

proc getImgUrlBase(repo: string, publishBranchName: string, plotsPath: string): string =
&"https://raw.githubusercontent.com/{repo}/refs/heads/{publishBranchName}/{plotsPath}"

proc extractTestName(base: string): string =
let parts = base.split("_")
if parts.len >= 2:
parts[^2]
else:
base

proc makeImgTag(imgUrl: string, width: int): string =
&"<img src=\"{imgUrl}\" width=\"{width}\" style=\"margin-right:10px;\" />"

proc prepareLatencyHistoryImage(
imgUrlBase: string, latencyHistoryFilePath: string, width: int = 600
): string =
let latencyImgUrl = &"{imgUrlBase}/{latencyHistoryFilePath}"
makeImgTag(latencyImgUrl, width)

proc prepareDockerStatsImages(
plotDir: string, imgUrlBase: string, branchName: string, width: int = 450
): Table[string, seq[string]] =
## Groups docker stats plot images by test name and returns HTML <img> tags.
var grouped: Table[string, seq[string]]

for path in walkFiles(&"{plotDir}/*.png"):
let plotFile = path.splitPath.tail
let testName = extractTestName(plotFile)
let imgUrl = &"{imgUrlBase}/{branchName}/{plotFile}"
let imgTag = makeImgTag(imgUrl, width)
discard grouped.hasKeyOrPut(testName, @[])
grouped[testName].add(imgTag)

grouped

proc buildSummary(
plotDir: string,
repo: string,
branchName: string,
publishBranchName: string,
plotsPath: string,
latencyHistoryFilePath: string,
): string =
let imgUrlBase = getImgUrlBase(repo, publishBranchName, plotsPath)

var buf: seq[string]

# Latency History section
buf.add("## Latency History")
buf.add(prepareLatencyHistoryImage(imgUrlBase, latencyHistoryFilePath) & "<br>")
buf.add("")

# Performance Plots section
let grouped = prepareDockerStatsImages(plotDir, imgUrlBase, branchName)
buf.add(&"## Performance Plots for {branchName}")
for test in grouped.keys.toSeq().sorted():
let imgs = grouped[test]
buf.add(&"### {test}")
buf.add(imgs.join(" ") & "<br>")

buf.join("\n")

proc main() =
let summaryPath = getEnv("GITHUB_STEP_SUMMARY", "/tmp/step_summary.md")
let repo = getEnv("GITHUB_REPOSITORY", "vacp2p/nim-libp2p")
let branchName = getEnv("BRANCH_NAME", "")
let publishBranchName = getEnv("PUBLISH_BRANCH_NAME", "performance_plots")
let plotsPath = getEnv("PLOTS_PATH", "plots")
let latencyHistoryFilePath =
getEnv("LATENCY_HISTORY_PLOT_FILENAME", "latency_history_all_scenarios.png")
let checkoutSubfolder = getEnv("CHECKOUT_SUBFOLDER", "subplots")
let plotDir = &"{checkoutSubfolder}/{plotsPath}/{branchName}"

let summary = buildSummary(
plotDir, repo, branchName, publishBranchName, plotsPath, latencyHistoryFilePath
)
writeFile(summaryPath, summary)
echo summary

main()
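For reference, a hedged sketch of what the helpers above yield for the default environment values and a hypothetical plot file name (derived directly from the format strings in the file):

# getImgUrlBase("vacp2p/nim-libp2p", "performance_plots", "plots") ==
#   "https://raw.githubusercontent.com/vacp2p/nim-libp2p/refs/heads/performance_plots/plots"
# extractTestName("docker_stats_BaseTest_cpu.png") == "BaseTest"   # parts[^2]
echo extractTestName("docker_stats_BaseTest_cpu.png")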
Some files were not shown because too many files have changed in this diff.