Compare commits


1 Commit

Author   SHA1         Message                 Date
Tanguy   ae67a9b3f6   Add queued send bytes   2023-02-13 17:32:07 +01:00
219 changed files with 15000 additions and 20670 deletions


@@ -1,2 +0,0 @@
# Formatted with nph 0.5.1
dc83a1e9b68f00b3be7e09febdb1a3f877321b9a


@@ -2,19 +2,19 @@ name: Bumper
on:
push:
branches:
- master
- unstable
- bumper
workflow_dispatch:
jobs:
bumpProjects:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
target: [
{ repo: status-im/nimbus-eth2, branch: unstable },
{ repo: waku-org/nwaku, branch: master },
{ repo: codex-storage/nim-codex, branch: master }
{ repo: status-im/nwaku, branch: master },
{ repo: status-im/nim-codex, branch: main }
]
steps:
- name: Clone repo
@@ -23,14 +23,13 @@ jobs:
repository: ${{ matrix.target.repo }}
ref: ${{ matrix.target.branch }}
path: nbc
submodules: true
fetch-depth: 0
token: ${{ secrets.ACTIONS_GITHUB_TOKEN }}
- name: Checkout this ref
run: |
cd nbc
git submodule update --init vendor/nim-libp2p
cd vendor/nim-libp2p
cd nbc/vendor/nim-libp2p
git checkout $GITHUB_SHA
- name: Commit this bump
@@ -38,7 +37,7 @@ jobs:
cd nbc
git config --global user.email "${{ github.actor }}@users.noreply.github.com"
git config --global user.name "${{ github.actor }}"
git commit --allow-empty -a -m "auto-bump nim-libp2p"
git commit -a -m "auto-bump nim-libp2p"
git branch -D nim-libp2p-auto-bump-${GITHUB_REF##*/} || true
git switch -c nim-libp2p-auto-bump-${GITHUB_REF##*/}
git push -f origin nim-libp2p-auto-bump-${GITHUB_REF##*/}


@@ -28,7 +28,7 @@ jobs:
cpu: amd64
#- os: windows
#cpu: i386
branch: [version-1-6]
branch: [version-1-2, version-1-6]
include:
- target:
os: linux
@@ -90,25 +90,3 @@ jobs:
nim --version
nimble --version
nimble test
lint:
name: "Lint"
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 2 # In PR, has extra merge commit: ^1 = PR, ^2 = base
- name: Check nph formatting
# Pin nph to a specific version to avoid sudden style differences.
# Updating nph version should be accompanied with running the new
# version on the fluffy directory.
run: |
VERSION="v0.5.1"
ARCHIVE="nph-linux_x64.tar.gz"
curl -L "https://github.com/arnetheduck/nph/releases/download/${VERSION}/${ARCHIVE}" -o ${ARCHIVE}
tar -xzf ${ARCHIVE}
shopt -s extglob # Enable extended globbing
./nph examples libp2p tests tools *.@(nim|nims|nimble)
git diff --exit-code


@@ -1,12 +0,0 @@
name: Daily
on:
schedule:
- cron: "30 6 * * *"
workflow_dispatch:
jobs:
call-multi-nim-common:
uses: ./.github/workflows/daily_common.yml
with:
nim-branch: "['version-1-6','version-2-0']"
cpu: "['amd64']"


@@ -1,84 +0,0 @@
name: daily-common
on:
workflow_call:
inputs:
nim-branch:
description: 'Nim branch'
required: true
type: string
cpu:
description: 'CPU'
required: true
type: string
exclude:
description: 'Exclude matrix configurations'
required: false
type: string
default: "[]"
jobs:
delete-cache:
runs-on: ubuntu-latest
steps:
- uses: snnaplab/delete-branch-cache-action@v1
build:
needs: delete-cache
timeout-minutes: 120
strategy:
fail-fast: false
matrix:
platform:
- os: linux
builder: ubuntu-20
shell: bash
- os: macos
builder: macos-12
shell: bash
- os: windows
builder: windows-2019
shell: msys2 {0}
branch: ${{ fromJSON(inputs.nim-branch) }}
cpu: ${{ fromJSON(inputs.cpu) }}
exclude: ${{ fromJSON(inputs.exclude) }}
defaults:
run:
shell: ${{ matrix.platform.shell }}
name: '${{ matrix.platform.os }}-${{ matrix.cpu }} (Nim ${{ matrix.branch }})'
runs-on: ${{ matrix.platform.builder }}
continue-on-error: ${{ matrix.branch == 'devel' || matrix.branch == 'version-2-0' }}
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Setup Nim
uses: "./.github/actions/install_nim"
with:
os: ${{ matrix.platform.os }}
shell: ${{ matrix.platform.shell }}
nim_branch: ${{ matrix.branch }}
cpu: ${{ matrix.cpu }}
- name: Setup Go
uses: actions/setup-go@v4
with:
go-version: '~1.15.5'
cache: false
- name: Install p2pd
run: |
V=1 bash scripts/build_p2pd.sh p2pdCache 124530a3
- name: Run tests
run: |
nim --version
nimble --version
nimble install -y --depsOnly
NIMFLAGS="${NIMFLAGS} --mm:refc" nimble test
if [[ "${{ matrix.branch }}" == "devel" ]]; then
echo -e "\nTesting with '--mm:orc':\n"
NIMFLAGS="${NIMFLAGS} --mm:orc" nimble test
fi


@@ -1,13 +0,0 @@
name: Daily i386
on:
schedule:
- cron: "30 6 * * *"
workflow_dispatch:
jobs:
call-multi-nim-common:
uses: ./.github/workflows/daily_common.yml
with:
nim-branch: "['version-1-6','version-2-0', 'devel']"
cpu: "['i386']"
exclude: "[{'platform': {'os':'macos'}}, {'platform': {'os':'windows'}}]"


@@ -1,12 +0,0 @@
name: Daily Nim Devel
on:
schedule:
- cron: "30 6 * * *"
workflow_dispatch:
jobs:
call-multi-nim-common:
uses: ./.github/workflows/daily_common.yml
with:
nim-branch: "['devel']"
cpu: "['amd64']"


@@ -1,8 +1,6 @@
name: Docgen
on:
push:
branches:
- master
workflow_dispatch:
@@ -21,7 +19,7 @@ jobs:
- uses: jiro4989/setup-nim-action@v1
with:
nim-version: '1.6.x'
nim-version: 'stable'
- name: Generate doc
run: |
@@ -29,7 +27,7 @@ jobs:
nimble --version
nimble install_pinned
# nim doc can "fail", but the doc is still generated
nim doc --git.url:https://github.com/vacp2p/nim-libp2p --git.commit:${GITHUB_REF##*/} --outdir:${GITHUB_REF##*/} --project libp2p || true
nim doc --git.url:https://github.com/status-im/nim-libp2p --git.commit:${GITHUB_REF##*/} --outdir:${GITHUB_REF##*/} --project libp2p || true
# check that the folder exists
ls ${GITHUB_REF##*/}
@@ -37,7 +35,7 @@ jobs:
- name: Clone the gh-pages branch
uses: actions/checkout@v2
with:
repository: vacp2p/nim-libp2p
repository: status-im/nim-libp2p
ref: gh-pages
path: subdoc
submodules: true
@@ -47,6 +45,9 @@ jobs:
run: |
cd subdoc
# Delete merged branches' docs
for branch in $(git branch -vv | grep ': gone]' | awk '{print $1}'); do rm -rf $branch; done
# Update / create this branch doc
rm -rf ${GITHUB_REF##*/}
mv ../${GITHUB_REF##*/} .
@@ -62,6 +63,7 @@ jobs:
git push origin gh-pages
update_site:
if: github.ref == 'refs/heads/master' || github.ref == 'refs/heads/docs'
name: 'Rebuild website'
runs-on: ubuntu-latest
steps:
@@ -77,12 +79,12 @@ jobs:
nim-version: 'stable'
- name: Generate website
run: pip install mkdocs-material && nimble -y website
run: pip install mkdocs-material && nimble website
- name: Clone the gh-pages branch
uses: actions/checkout@v2
with:
repository: vacp2p/nim-libp2p
repository: status-im/nim-libp2p
ref: gh-pages
path: subdoc
fetch-depth: 0
@@ -91,21 +93,11 @@ jobs:
run: |
cd subdoc
# Ensure the latest changes are fetched and reset to the remote branch
git fetch origin gh-pages
git reset --hard origin/gh-pages
rm -rf docs
mv ../site docs
git add .
if git diff-index --quiet HEAD --; then
echo "No changes to commit"
else
git config --global user.email "${{ github.actor }}@users.noreply.github.com"
git config --global user.name "${{ github.actor }}"
git commit -m "update website"
git push origin gh-pages
fi
git config --global user.email "${{ github.actor }}@users.noreply.github.com"
git config --global user.name = "${{ github.actor }}"
git commit -a -m "update website"
git push origin gh-pages


@@ -1,40 +0,0 @@
name: Interoperability Testing
on:
pull_request:
push:
branches:
- unstable
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
jobs:
run-transport-interop:
name: Run transport interoperability tests
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- uses: docker/setup-buildx-action@v3
- name: Build image
run: docker buildx build --load -t nim-libp2p-head -f tests/transport-interop/Dockerfile .
- name: Run tests
uses: libp2p/test-plans/.github/actions/run-transport-interop-test@master
with:
test-filter: nim-libp2p-head
extra-versions: ${{ github.workspace }}/tests/transport-interop/version.json
run-hole-punching-interop:
name: Run hole-punching interoperability tests
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- uses: docker/setup-buildx-action@v3
- name: Build image
run: docker buildx build --load -t nim-libp2p-head -f tests/hole-punching-interop/Dockerfile .
- name: Run tests
uses: libp2p/test-plans/.github/actions/run-interop-hole-punch-test@master
with:
test-filter: nim-libp2p-head
extra-versions: ${{ github.workspace }}/tests/hole-punching-interop/version.json

.github/workflows/multi_nim.yml

@@ -0,0 +1,82 @@
name: Daily
on:
schedule:
- cron: "30 6 * * *"
workflow_dispatch:
jobs:
delete-cache:
runs-on: ubuntu-latest
steps:
- uses: snnaplab/delete-branch-cache-action@v1
build:
needs: delete-cache
timeout-minutes: 120
strategy:
fail-fast: false
matrix:
target:
- os: linux
cpu: amd64
- os: linux
cpu: i386
- os: macos
cpu: amd64
- os: windows
cpu: amd64
#- os: windows
#cpu: i386
branch: [version-1-2, version-1-6, devel]
include:
- target:
os: linux
builder: ubuntu-20.04
shell: bash
- target:
os: macos
builder: macos-12
shell: bash
- target:
os: windows
builder: windows-2019
shell: msys2 {0}
defaults:
run:
shell: ${{ matrix.shell }}
name: '${{ matrix.target.os }}-${{ matrix.target.cpu }} (Nim ${{ matrix.branch }})'
runs-on: ${{ matrix.builder }}
continue-on-error: ${{ matrix.branch == 'devel' }}
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Setup Nim
uses: "./.github/actions/install_nim"
with:
os: ${{ matrix.target.os }}
shell: ${{ matrix.shell }}
nim_branch: ${{ matrix.branch }}
cpu: ${{ matrix.target.cpu }}
- name: Setup Go
uses: actions/setup-go@v2
with:
go-version: '~1.15.5'
- name: Install p2pd
run: |
V=1 bash scripts/build_p2pd.sh p2pdCache 124530a3
- name: Run tests
run: |
nim --version
nimble --version
nimble install -y --depsOnly
NIMFLAGS="${NIMFLAGS} --gc:refc" nimble test
if [[ "${{ matrix.branch }}" == "devel" ]]; then
echo -e "\nTesting with '--gc:orc':\n"
NIMFLAGS="${NIMFLAGS} --gc:orc" nimble test
fi

.gitignore

@@ -16,4 +16,3 @@ tests/pubsub/testgossipsub
examples/*.md
nimble.develop
nimble.paths
go-libp2p-daemon/

.pinned

@@ -1,17 +1,16 @@
bearssl;https://github.com/status-im/nim-bearssl@#e4157639db180e52727712a47deaefcbbac6ec86
bearssl;https://github.com/status-im/nim-bearssl@#a647994910904b0103a05db3a5ec1ecfc4d91a88
chronicles;https://github.com/status-im/nim-chronicles@#32ac8679680ea699f7dbc046e8e0131cac97d41a
chronos;https://github.com/status-im/nim-chronos@#672db137b7cad9b384b8f4fb551fb6bbeaabfe1b
dnsclient;https://github.com/ba0f3/dnsclient.nim@#23214235d4784d24aceed99bbfe153379ea557c8
faststreams;https://github.com/status-im/nim-faststreams@#720fc5e5c8e428d9d0af618e1e27c44b42350309
httputils;https://github.com/status-im/nim-http-utils@#3b491a40c60aad9e8d3407443f46f62511e63b18
json_serialization;https://github.com/status-im/nim-json-serialization@#85b7ea093cb85ee4f433a617b97571bd709d30df
metrics;https://github.com/status-im/nim-metrics@#6142e433fc8ea9b73379770a788017ac528d46ff
nimcrypto;https://github.com/cheatfate/nimcrypto@#1c8d6e3caf3abc572136ae9a1da81730c4eb4288
results;https://github.com/arnetheduck/nim-results@#f3c666a272c69d70cb41e7245e7f6844797303ad
secp256k1;https://github.com/status-im/nim-secp256k1@#7246d91c667f4cc3759fdd50339caa45a2ecd8be
serialization;https://github.com/status-im/nim-serialization@#4bdbc29e54fe54049950e352bb969aab97173b35
stew;https://github.com/status-im/nim-stew@#3159137d9a3110edb4024145ce0ba778975de40e
chronos;https://github.com/status-im/nim-chronos@#75d030ff71264513fb9701c75a326cd36fcb4692
dnsclient;https://github.com/ba0f3/dnsclient.nim@#fcd7443634b950eaea574e5eaa00a628ae029823
faststreams;https://github.com/status-im/nim-faststreams@#b42daf41d8eb4fbce40add6836bed838f8d85b6f
httputils;https://github.com/status-im/nim-http-utils@#a85bd52ae0a956983ca6b3267c72961d2ec0245f
json_serialization;https://github.com/status-im/nim-json-serialization@#a7d815ed92f200f490c95d3cfd722089cc923ce6
metrics;https://github.com/status-im/nim-metrics@#21e99a2e9d9f80e68bef65c80ef781613005fccb
nimcrypto;https://github.com/cheatfate/nimcrypto@#24e006df85927f64916e60511620583b11403178
secp256k1;https://github.com/status-im/nim-secp256k1@#fd173fdff863ce2e211cf64c9a03bc7539fe40b0
serialization;https://github.com/status-im/nim-serialization@#d77417cba6896c26287a68e6a95762e45a1b87e5
stew;https://github.com/status-im/nim-stew@#7184d2424dc3945657884646a72715d494917aad
testutils;https://github.com/status-im/nim-testutils@#dfc4c1b39f9ded9baf6365014de2b4bfb4dafc34
unittest2;https://github.com/status-im/nim-unittest2@#2300fa9924a76e6c96bc4ea79d043e3a0f27120c
websock;https://github.com/status-im/nim-websock@#f8ed9b40a5ff27ad02a3c237c4905b0924e3f982
zlib;https://github.com/status-im/nim-zlib@#38b72eda9d70067df4a953f56b5ed59630f2a17b
unittest2;https://github.com/status-im/nim-unittest2@#da8398c45cafd5bd7772da1fc96e3924a18d3823
websock;https://github.com/status-im/nim-websock@#691f069b209d372b1240d5ae1f57fb7bbafeaba7
zlib;https://github.com/status-im/nim-zlib@#6a6670afba6b97b29b920340e2641978c05ab4d8


@@ -5,8 +5,8 @@
<h3 align="center">The <a href="https://nim-lang.org/">Nim</a> implementation of the <a href="https://libp2p.io/">libp2p</a> Networking Stack.</h3>
<p align="center">
<a href="https://github.com/vacp2p/nim-libp2p/actions"><img src="https://github.com/vacp2p/nim-libp2p/actions/workflows/ci.yml/badge.svg" /></a>
<a href="https://codecov.io/gh/vacp2p/nim-libp2p"><img src="https://codecov.io/gh/vacp2p/nim-libp2p/branch/master/graph/badge.svg?token=UR5JRQ249W"/></a>
<a href="https://github.com/status-im/nim-libp2p/actions"><img src="https://github.com/status-im/nim-libp2p/actions/workflows/ci.yml/badge.svg" /></a>
<a href="https://codecov.io/gh/status-im/nim-libp2p"><img src="https://codecov.io/gh/status-im/nim-libp2p/branch/master/graph/badge.svg?token=UR5JRQ249W"/></a>
</p>
@@ -20,7 +20,6 @@
- [Background](#background)
- [Install](#install)
- [Getting Started](#getting-started)
- [Go-libp2p-daemon](#go-libp2p-daemon)
- [Modules](#modules)
- [Users](#users)
- [Stability](#stability)
@@ -41,20 +40,18 @@ Learn more about libp2p at [**libp2p.io**](https://libp2p.io) and follow libp2p'
## Install
**Prerequisite**
- [Nim](https://nim-lang.org/install.html)
> The currently supported Nim version is 1.6.18.
```
nimble install libp2p
```
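As a quick sanity check after installing — a minimal sketch (hedged; it uses the `newStandardSwitch` helper described in the docs below, with its defaults) that starts a node and prints its listen addresses:

```nim
import chronos, libp2p

proc main() {.async.} =
  let switch = newStandardSwitch() # default TCP + Noise + Mplex stack
  await switch.start()
  echo "Listening on: ", switch.peerInfo.addrs
  await switch.stop()

waitFor main()
```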
## Getting Started
You'll find the nim-libp2p documentation [here](https://vacp2p.github.io/nim-libp2p/docs/).
You'll find the nim-libp2p documentation [here](https://status-im.github.io/nim-libp2p/docs/).
### Testing
Remember you'll need to build the `go-libp2p-daemon` binary to run the `nim-libp2p` tests.
To do so, please follow the installation instructions in [daemonapi.md](examples/go-daemon/daemonapi.md).
**Go Daemon:**
Please find the installation and usage instructions in [daemonapi.md](examples/go-daemon/daemonapi.md).
## Modules
List of packages and modules implemented in nim-libp2p:
| Name | Description |
@@ -71,6 +68,7 @@ List of packages modules implemented in nim-libp2p:
| [libp2p-ws](libp2p/transports/wstransport.nim) | WebSocket & WebSocket Secure transport |
| [libp2p-tor](libp2p/transports/tortransport.nim) | Tor Transport |
| **Secure Channels** | |
| [libp2p-secio](libp2p/protocols/secure/secio.nim) | Secio secure channel |
| [libp2p-noise](libp2p/protocols/secure/noise.nim) | [Noise](https://docs.libp2p.io/concepts/secure-comm/noise/) secure channel |
| [libp2p-plaintext](libp2p/protocols/secure/plaintext.nim) | Plain Text for development purposes |
| **Stream Multiplexers** | |
@@ -95,8 +93,8 @@ List of packages modules implemented in nim-libp2p:
nim-libp2p is used by:
- [Nimbus](https://github.com/status-im/nimbus-eth2), an Ethereum client
- [nwaku](https://github.com/waku-org/nwaku), a decentralized messaging application
- [nim-codex](https://github.com/codex-storage/nim-codex), a decentralized storage application
- [nwaku](https://github.com/status-im/nwaku), a decentralized messaging application
- [nim-codex](https://github.com/status-im/nim-codex), a decentralized storage application
- (open a pull request if you want to be included here)
## Stability
@@ -107,12 +105,12 @@ The versioning follows [semver](https://semver.org/), with some additions:
- Some of libp2p procedures are marked as `.public.`, they will remain compatible during each `MAJOR` version
- The rest of the procedures are considered internal, and can change at any `MINOR` version (but remain compatible for each new `PATCH`)
We aim to be compatible at all times with at least 2 Nim `MINOR` versions, currently `1.6 & 2.0`
We aim to be compatible at all times with at least 2 Nim `MINOR` versions, currently `1.2 & 1.6`
## Development
Clone and Install dependencies:
```sh
git clone https://github.com/vacp2p/nim-libp2p
git clone https://github.com/status-im/nim-libp2p
cd nim-libp2p
# to use dependencies computed by nimble
nimble install -dy
@@ -134,12 +132,11 @@ The libp2p implementation in Nim is a work in progress. We welcome contributors
- Go through the modules and **check out existing issues**. This would be especially useful for modules in active development. Some knowledge of IPFS/libp2p may be required, as well as the infrastructure behind it.
- **Perform code reviews**. Feel free to let us know if you found anything that can a) speed up the project development b) ensure better quality and c) reduce possible future bugs.
- **Add tests**. Help nim-libp2p to be more robust by adding more tests to the [tests folder](tests/).
- **Small PRs**. Try to keep PRs atomic and digestible. This makes the review process and pinpointing bugs easier.
- **Code format**. Please format code using [nph](https://github.com/arnetheduck/nph) v0.5.1. This will ensure a consistent codebase and make PRs easier to review. A CI rule has been added to ensure that future commits are all formatted using the same nph version.
The code follows the [Status Nim Style Guide](https://status-im.github.io/nim-style-guide/).
### Contributors
<a href="https://github.com/vacp2p/nim-libp2p/graphs/contributors"><img src="https://contrib.rocks/image?repo=vacp2p/nim-libp2p" alt="nim-libp2p contributors"></a>
<a href="https://github.com/status-im/nim-libp2p/graphs/contributors"><img src="https://contrib.rocks/image?repo=status-im/nim-libp2p" alt="nim-libp2p contributors"></a>
### Core Maintainers
<table>


@@ -7,17 +7,17 @@ if dirExists("nimbledeps/pkgs2"):
switch("warning", "CaseTransition:off")
switch("warning", "ObservableStores:off")
switch("warning", "LockLevel:off")
--styleCheck:
usages
switch("warningAsError", "UseBase:on")
--styleCheck:
error
--define:chronosStrictException
--styleCheck:usages
if (NimMajor, NimMinor) < (1, 6):
--styleCheck:hint
else:
--styleCheck:error
# Avoid some rare stack corruption while using exceptions with a SEH-enabled
# toolchain: https://github.com/status-im/nimbus-eth2/issues/3121
if defined(windows) and not defined(vcc):
--define:
nimRawSetjmp
--define:nimRawSetjmp
# begin Nimble config (version 1)
when fileExists("nimble.paths"):


@@ -1,25 +1,25 @@
## # Circuit Relay example
##
## Circuit Relay can be used when a node cannot reach another node
## directly, but can reach it through another node (the Relay).
## directly, but can reach it through a another node (the Relay).
##
## That may happen because of NAT, Firewalls, or incompatible transports.
##
## More information [here](https://docs.libp2p.io/concepts/circuit-relay/).
import chronos, stew/byteutils
import libp2p, libp2p/protocols/connectivity/relay/[relay, client]
import libp2p,
libp2p/protocols/connectivity/relay/[relay, client]
# Helper to create a circuit relay node
proc createCircuitRelaySwitch(r: Relay): Switch =
SwitchBuilder
.new()
.withRng(newRng())
.withAddresses(@[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()])
.withTcpTransport()
.withMplex()
.withNoise()
.withCircuitRelay(r)
.build()
SwitchBuilder.new()
.withRng(newRng())
.withAddresses(@[ MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() ])
.withTcpTransport()
.withMplex()
.withNoise()
.withCircuitRelay(r)
.build()
proc main() {.async.} =
# Create a custom protocol
@@ -56,11 +56,8 @@ proc main() {.async.} =
let
# Create a relay address to swDst using swRel as the relay
addrs = MultiAddress
.init(
$swRel.peerInfo.addrs[0] & "/p2p/" & $swRel.peerInfo.peerId & "/p2p-circuit"
)
.get()
addrs = MultiAddress.init($swRel.peerInfo.addrs[0] & "/p2p/" &
$swRel.peerInfo.peerId & "/p2p-circuit").get()
# Connect Dst to the relay
await swDst.connect(swRel.peerInfo.peerId, swRel.peerInfo.addrs)
@@ -69,7 +66,7 @@ proc main() {.async.} =
let rsvp = await clDst.reserve(swRel.peerInfo.peerId, swRel.peerInfo.addrs)
# Src dial Dst using the relay
let conn = await swSrc.dial(swDst.peerInfo.peerId, @[addrs], customProtoCodec)
let conn = await swSrc.dial(swDst.peerInfo.peerId, @[ addrs ], customProtoCodec)
await conn.writeLp("test1")
var msg = string.fromBytes(await conn.readLp(1024))
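For reference, the relayed address dialed above is simply the relay's own address with its peer id and `/p2p-circuit` appended. A standalone sketch, with hypothetical local values in place of a real relay:

```nim
import libp2p

let
  rng = newRng()
  relayMa = MultiAddress.init("/ip4/127.0.0.1/tcp/4040").tryGet()
  relayId = PeerId.random(rng).tryGet() # hypothetical relay peer id
  # <relay addr>/p2p/<relay peer id>/p2p-circuit, as built in the example above
  circuitMa = MultiAddress.init($relayMa & "/p2p/" & $relayId & "/p2p-circuit").tryGet()

echo circuitMa
```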


@@ -1,12 +1,15 @@
when not (compileOption("threads")):
when not(compileOption("threads")):
{.fatal: "Please, compile this program with the --threads:on option!".}
import strformat, strutils, stew/byteutils, chronos, libp2p
import
strformat, strutils,
stew/byteutils,
chronos,
libp2p
const DefaultAddr = "/ip4/127.0.0.1/tcp/0"
const Help =
"""
const Help = """
Commands: /[?|help|connect|disconnect|exit]
help: Prints this help
connect: dials a remote peer
@@ -14,11 +17,12 @@ const Help =
exit: closes the chat
"""
type Chat = ref object
switch: Switch # a single entry point for dialing and listening to peers
stdinReader: StreamTransport # transport streams between read & write file descriptor
conn: Connection # connection to the other peer
connected: bool # if the node is connected to another peer
type
Chat = ref object
switch: Switch # a single entry point for dialing and listening to peers
stdinReader: StreamTransport # transport streams between read & write file descriptor
conn: Connection # connection to the other peer
connected: bool # if the node is connected to another peer
##
# Stdout helpers, to write the prompt
@@ -37,7 +41,8 @@ proc writeStdout(c: Chat, str: string) =
##
const ChatCodec = "/nim-libp2p/chat/1.0.0"
type ChatProto = ref object of LPProtocol
type
ChatProto = ref object of LPProtocol
proc new(T: typedesc[ChatProto], c: Chat): T =
let chatproto = T()
@@ -72,9 +77,9 @@ proc handlePeer(c: Chat, conn: Connection) {.async.} =
strData = await conn.readLp(1024)
str = string.fromBytes(strData)
c.writeStdout $conn.peerId & ": " & $str
except LPStreamEOFError:
defer:
c.writeStdout $conn.peerId & " disconnected"
defer: c.writeStdout $conn.peerId & " disconnected"
await c.conn.close()
c.connected = false
@@ -83,7 +88,10 @@ proc dialPeer(c: Chat, address: string) {.async.} =
let
multiAddr = MultiAddress.init(address).tryGet()
# split the peerId part /p2p/...
peerIdBytes = multiAddr[multiCodec("p2p")].tryGet().protoAddress().tryGet()
peerIdBytes = multiAddr[multiCodec("p2p")]
.tryGet()
.protoAddress()
.tryGet()
remotePeer = PeerId.init(peerIdBytes).tryGet()
# split the wire address
ip4Addr = multiAddr[multiCodec("ip4")].tryGet()
@@ -116,6 +124,7 @@ proc readLoop(c: Chat) {.async.} =
let address = await c.stdinReader.readLine()
if address.len > 0:
await c.dialPeer(address)
elif line.startsWith("/exit"):
if c.connected and c.conn.closed.not:
await c.conn.close()
@@ -162,18 +171,16 @@ proc main() {.async.} =
var switch = SwitchBuilder
.new()
.withRng(rng)
# Give the application RNG
.withRng(rng) # Give the application RNG
.withAddress(localAddress)
.withTcpTransport()
# Use TCP as transport
.withMplex()
# Use Mplex as muxer
.withNoise()
# Use Noise as secure manager
.withTcpTransport() # Use TCP as transport
.withMplex() # Use Mplex as muxer
.withNoise() # Use Noise as secure manager
.build()
let chat = Chat(switch: switch, stdinReader: stdinReader)
let chat = Chat(
switch: switch,
stdinReader: stdinReader)
switch.mount(ChatProto.new(chat))


@@ -2,7 +2,8 @@ import chronos, nimcrypto, strutils
import ../../libp2p/daemon/daemonapi
import ../hexdump
const PubSubTopic = "test-net"
const
PubSubTopic = "test-net"
proc dumpSubscribedPeers(api: DaemonAPI) {.async.} =
var peers = await api.pubsubListPeers(PubSubTopic)
@@ -36,12 +37,12 @@ proc main() {.async.} =
asyncSpawn monitor(api)
proc pubsubLogger(
api: DaemonAPI, ticket: PubsubTicket, message: PubSubMessage
): Future[bool] {.async.} =
proc pubsubLogger(api: DaemonAPI,
ticket: PubsubTicket,
message: PubSubMessage): Future[bool] {.async.} =
let msglen = len(message.data)
echo "= Recieved pubsub message with length ",
msglen, " bytes from peer ", message.peer.pretty()
echo "= Recieved pubsub message with length ", msglen,
" bytes from peer ", message.peer.pretty()
echo dumpHex(message.data)
await api.dumpSubscribedPeers()
result = true


@@ -2,16 +2,18 @@ import chronos, nimcrypto, strutils
import ../../libp2p/daemon/daemonapi
## nim c -r --threads:on chat.nim
when not (compileOption("threads")):
when not(compileOption("threads")):
{.fatal: "Please, compile this program with the --threads:on option!".}
const ServerProtocols = @["/test-chat-stream"]
const
ServerProtocols = @["/test-chat-stream"]
type CustomData = ref object
api: DaemonAPI
remotes: seq[StreamTransport]
consoleFd: AsyncFD
serveFut: Future[void]
type
CustomData = ref object
api: DaemonAPI
remotes: seq[StreamTransport]
consoleFd: AsyncFD
serveFut: Future[void]
proc threadMain(wfd: AsyncFD) {.thread.} =
## This procedure performs reading from `stdin` and sends data over
@@ -80,7 +82,7 @@ proc serveThread(udata: CustomData) {.async.} =
relay = true
break
if relay:
echo peer.pretty(), " * ", " [", addresses.join(", "), "]"
echo peer.pretty(), " * ", " [", addresses.join(", "), "]"
else:
echo peer.pretty(), " [", addresses.join(", "), "]"
elif line.startsWith("/exit"):


@@ -1,8 +1,6 @@
# Table of Contents
- [Introduction](#introduction)
- [Prerequisites](#prerequisites)
- [Installation](#installation)
- [Script](#script)
- [Usage](#usage)
- [Example](#example)
- [Getting Started](#getting-started)
@@ -10,29 +8,26 @@
# Introduction
This is a libp2p-backed daemon wrapping the functionality of go-libp2p for use in Nim. <br>
For more information about the go daemon, check out [this repository](https://github.com/libp2p/go-libp2p-daemon).
> **Required only** for running the tests.
# Prerequisites
Go with version `1.15.15`.
> You will *likely* be able to build `go-libp2p-daemon` with different Go versions, but **they haven't been tested**.
# Installation
Follow one of the methods below:
## Script
Run the build script with the `go` command pointing to the correct Go version.
We recommend using `1.15.15`, as previously stated.
```sh
./scripts/build_p2pd.sh
```
If everything goes correctly, the binary (`p2pd`) should be built and placed in the correct directory.
If you run into any issues, please head to our Discord and ask for assistance.
# clone and install dependencies
git clone https://github.com/status-im/nim-libp2p
cd nim-libp2p
nimble install
After successfully building the binary, remember to add it to your path so it can be found. You can do that by running:
```sh
export PATH="$PATH:$HOME/go/bin"
# perform unit tests
nimble test
# update the git submodule to install the go daemon
git submodule update --init --recursive
go version
git clone https://github.com/libp2p/go-libp2p-daemon
cd go-libp2p-daemon
git checkout v0.0.1
go install ./...
cd ..
```
> **Tip:** To make this change permanent, add the command above to your `.bashrc` file.
# Usage
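A minimal, hedged sketch of talking to the daemon from Nim once `p2pd` is on your `PATH` (default flags assumed; see the examples in this directory for fuller usage):

```nim
import chronos
import libp2p/daemon/daemonapi

proc main() {.async.} =
  # spawns a local p2pd instance with default flags and connects to it
  let api = await newDaemonApi()
  let id = await api.identity()
  echo "daemon peer: ", id.peer.pretty()
  for address in id.addresses:
    echo address
  await api.close()

waitFor(main())
```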


@@ -1,25 +1,26 @@
import chronos, nimcrypto, strutils, os
import ../../libp2p/daemon/daemonapi
const PubSubTopic = "test-net"
const
PubSubTopic = "test-net"
proc main(bn: string) {.async.} =
echo "= Starting P2P node"
var bootnodes = bn.split(",")
var api = await newDaemonApi(
{DHTFull, PSGossipSub, WaitBootstrap}, bootstrapNodes = bootnodes, peersRequired = 1
)
var api = await newDaemonApi({DHTFull, PSGossipSub, WaitBootstrap},
bootstrapNodes = bootnodes,
peersRequired = 1)
var id = await api.identity()
echo "= P2P node ", id.peer.pretty(), " started:"
for item in id.addresses:
echo item
proc pubsubLogger(
api: DaemonAPI, ticket: PubsubTicket, message: PubSubMessage
): Future[bool] {.async.} =
proc pubsubLogger(api: DaemonAPI,
ticket: PubsubTicket,
message: PubSubMessage): Future[bool] {.async.} =
let msglen = len(message.data)
echo "= Recieved pubsub message with length ",
msglen, " bytes from peer ", message.peer.pretty(), ": "
echo "= Recieved pubsub message with length ", msglen,
" bytes from peer ", message.peer.pretty(), ": "
var strdata = cast[string](message.data)
echo strdata
result = true


@@ -1,5 +1,5 @@
import chronos # an efficient library for async
import stew/byteutils # various utils
import chronos # an efficient library for async
import stew/byteutils # various utils
import libp2p
##
@@ -7,18 +7,20 @@ import libp2p
##
const TestCodec = "/test/proto/1.0.0" # custom protocol string identifier
type TestProto = ref object of LPProtocol # declare a custom protocol
type
TestProto = ref object of LPProtocol # declare a custom protocol
proc new(T: typedesc[TestProto]): T =
# every incoming connection will be handled in this closure
proc handle(conn: Connection, proto: string) {.async.} =
proc handle(conn: Connection, proto: string) {.async, gcsafe.} =
echo "Got from remote - ", string.fromBytes(await conn.readLp(1024))
await conn.writeLp("Roger p2p!")
# We must close the connections ourselves when we're done with it
await conn.close()
return T.new(codecs = @[TestCodec], handler = handle)
return T(codecs: @[TestCodec], handler: handle)
##
# Helper to create a switch/node
@@ -26,16 +28,11 @@ proc new(T: typedesc[TestProto]): T =
proc createSwitch(ma: MultiAddress, rng: ref HmacDrbgContext): Switch =
var switch = SwitchBuilder
.new()
.withRng(rng)
# Give the application RNG
.withAddress(ma)
# Our local address(es)
.withTcpTransport()
# Use TCP as transport
.withMplex()
# Use Mplex as muxer
.withNoise()
# Use Noise as secure manager
.withRng(rng) # Give the application RNG
.withAddress(ma) # Our local address(es)
.withTcpTransport() # Use TCP as transport
.withMplex() # Use Mplex as muxer
.withNoise() # Use Noise as secure manager
.build()
result = switch
@@ -43,7 +40,7 @@ proc createSwitch(ma: MultiAddress, rng: ref HmacDrbgContext): Switch =
##
# The actual application
##
proc main() {.async.} =
proc main() {.async, gcsafe.} =
let
rng = newRng() # Single random number source for the whole application
# port 0 will take a random available port
@@ -76,8 +73,7 @@ proc main() {.async.} =
# use the second node to dial the first node
# using the first node peerid and address
# and specify our custom protocol codec
let conn =
await switch2.dial(switch1.peerInfo.peerId, switch1.peerInfo.addrs, TestCodec)
let conn = await switch2.dial(switch1.peerInfo.peerId, switch1.peerInfo.addrs, TestCodec)
# conn is now a fully setup connection, we talk directly to the node1 custom protocol handler
await conn.writeLp("Hello p2p!") # writeLp send a length prefixed buffer over the wire
@@ -88,7 +84,6 @@ proc main() {.async.} =
# We must close the connection ourselves when we're done with it
await conn.close()
await allFutures(switch1.stop(), switch2.stop())
# close connections and shutdown all transports
await allFutures(switch1.stop(), switch2.stop()) # close connections and shutdown all transports
waitFor(main())


@@ -35,24 +35,26 @@ proc dumpHex*(pbytes: pointer, nbytes: int, items = 1, ascii = true): string =
var asciiText = ""
while i < nbytes:
if i %% 16 == 0:
result = result & toHex(cast[BiggestInt](slider), sizeof(BiggestInt) * 2) & ": "
result = result & toHex(cast[BiggestInt](slider),
sizeof(BiggestInt) * 2) & ": "
var k = 0
while k < items:
var ch = cast[ptr char](cast[uint](slider) + k.uint)[]
if ord(ch) > 31 and ord(ch) < 127:
asciiText &= ch
else:
asciiText &= "."
if ord(ch) > 31 and ord(ch) < 127: asciiText &= ch else: asciiText &= "."
inc(k)
case items
case items:
of 1:
result = result & toHex(cast[BiggestInt](cast[ptr uint8](slider)[]), hexSize)
result = result & toHex(cast[BiggestInt](cast[ptr uint8](slider)[]),
hexSize)
of 2:
result = result & toHex(cast[BiggestInt](cast[ptr uint16](slider)[]), hexSize)
result = result & toHex(cast[BiggestInt](cast[ptr uint16](slider)[]),
hexSize)
of 4:
result = result & toHex(cast[BiggestInt](cast[ptr uint32](slider)[]), hexSize)
result = result & toHex(cast[BiggestInt](cast[ptr uint32](slider)[]),
hexSize)
of 8:
result = result & toHex(cast[BiggestInt](cast[ptr uint64](slider)[]), hexSize)
result = result & toHex(cast[BiggestInt](cast[ptr uint64](slider)[]),
hexSize)
else:
raise newException(ValueError, "Wrong items size!")
result = result & " "
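A usage sketch for the helper above (buffer contents invented for illustration; assumes `dumpHex` is imported from this module):

```nim
var buf = @[byte 0x48, 0x65, 0x6c, 0x6c, 0x6f] # "Hello"
# one byte per item, ASCII column enabled (the defaults)
echo dumpHex(addr buf[0], buf.len)
```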


@@ -34,20 +34,15 @@ import libp2p/protocols/ping
## [chronos](https://github.com/status-im/nim-chronos) the asynchronous framework used by `nim-libp2p`
##
## Next, we'll create a helper procedure to create our switches. A switch needs a bit of configuration, and it will be easier to do this configuration only once:
## Next, we'll create an helper procedure to create our switches. A switch needs a bit of configuration, and it will be easier to do this configuration only once:
proc createSwitch(ma: MultiAddress, rng: ref HmacDrbgContext): Switch =
var switch = SwitchBuilder
.new()
.withRng(rng)
# Give the application RNG
.withAddress(ma)
# Our local address(es)
.withTcpTransport()
# Use TCP as transport
.withMplex()
# Use Mplex as muxer
.withNoise()
# Use Noise as secure manager
.withRng(rng) # Give the application RNG
.withAddress(ma) # Our local address(es)
.withTcpTransport() # Use TCP as transport
.withMplex() # Use Mplex as muxer
.withNoise() # Use Noise as secure manager
.build()
return switch
@@ -58,11 +53,11 @@ proc createSwitch(ma: MultiAddress, rng: ref HmacDrbgContext): Switch =
##
##
## Let's now start to create our main procedure:
proc main() {.async.} =
proc main() {.async, gcsafe.} =
let
rng = newRng()
localAddress = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()
pingProtocol = Ping.new(rng = rng)
pingProtocol = Ping.new(rng=rng)
## We created some variables that we'll need for the rest of the application: the global `rng` instance, our `localAddress`, and an instance of the `Ping` protocol.
## The address is in the [MultiAddress](https://github.com/multiformats/multiaddr) format. The port `0` means "take any port available".
##
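A tiny, hedged illustration of the MultiAddress format mentioned here (a fixed port instead of `0`, purely for demonstration):

```nim
import libp2p

let ma = MultiAddress.init("/ip4/127.0.0.1/tcp/4040").tryGet()
echo ma # /ip4/127.0.0.1/tcp/4040
echo MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() # the "any port" form used above
```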
@@ -82,9 +77,8 @@ proc main() {.async.} =
## Now that we've started the nodes, they are listening for incoming peers.
## We can find out which port was attributed, and the resulting local addresses, by using `switch1.peerInfo.addrs`.
##
## We'll **dial** the first switch from the second one, by specifying its **Peer ID**, its **MultiAddress** and the **`Ping` protocol codec**:
let conn =
await switch2.dial(switch1.peerInfo.peerId, switch1.peerInfo.addrs, PingCodec)
## We'll **dial** the first switch from the second one, by specifying it's **Peer ID**, it's **MultiAddress** and the **`Ping` protocol codec**:
let conn = await switch2.dial(switch1.peerInfo.peerId, switch1.peerInfo.addrs, PingCodec)
## We now have a `Ping` connection setup between the second and the first switch, we can use it to actually ping the node:
# ping the other node and echo the ping duration
echo "ping: ", await pingProtocol.ping(conn)
@@ -92,8 +86,7 @@ proc main() {.async.} =
# We must close the connection ourselves when we're done with it
await conn.close()
## And that's it! Just a little bit of cleanup: shutting down the switches, waiting for them to stop, and we'll call our `main` procedure:
await allFutures(switch1.stop(), switch2.stop())
# close connections and shutdown all transports
await allFutures(switch1.stop(), switch2.stop()) # close connections and shutdown all transports
waitFor(main())


@@ -18,14 +18,14 @@ type TestProto = ref object of LPProtocol
## We've set a [protocol ID](https://docs.libp2p.io/concepts/protocols/#protocol-ids), and created a custom `LPProtocol`. In a more complex protocol, we could use this structure to store interesting variables.
##
## A protocol generally has two parts: a handling/server part, and a dialing/client part.
## These two parts can be identical, but in our trivial protocol, the server will wait for a message from the client, and the client will send a message, so we have to handle the two cases separately.
## A protocol generally has two part: and handling/server part, and a dialing/client part.
## Theses two parts can be identical, but in our trivial protocol, the server will wait for a message from the client, and the client will send a message, so we have to handle the two cases separately.
##
## Let's start with the server part:
proc new(T: typedesc[TestProto]): T =
# every incoming connection will be handled in this closure
proc handle(conn: Connection, proto: string) {.async.} =
proc handle(conn: Connection, proto: string) {.async, gcsafe.} =
# Read up to 1024 bytes from this connection, and transform them into
# a string
echo "Got from remote - ", string.fromBytes(await conn.readLp(1024))
@@ -41,31 +41,29 @@ proc new(T: typedesc[TestProto]): T =
proc hello(p: TestProto, conn: Connection) {.async.} =
await conn.writeLp("Hello p2p!")
## Again, pretty straightforward, we just send a message on the connection.
## Again, pretty straight-forward, we just send a message on the connection.
##
## We can now create our main procedure:
proc main() {.async.} =
proc main() {.async, gcsafe.} =
let
rng = newRng()
testProto = TestProto.new()
switch1 = newStandardSwitch(rng = rng)
switch2 = newStandardSwitch(rng = rng)
switch1 = newStandardSwitch(rng=rng)
switch2 = newStandardSwitch(rng=rng)
switch1.mount(testProto)
await switch1.start()
await switch2.start()
let conn =
await switch2.dial(switch1.peerInfo.peerId, switch1.peerInfo.addrs, TestCodec)
let conn = await switch2.dial(switch1.peerInfo.peerId, switch1.peerInfo.addrs, TestCodec)
await testProto.hello(conn)
# We must close the connection ourselves when we're done with it
await conn.close()
await allFutures(switch1.stop(), switch2.stop())
# close connections and shutdown all transports
await allFutures(switch1.stop(), switch2.stop()) # close connections and shutdown all transports
## This is very similar to the first tutorial's `main`, the only noteworthy difference is that we use `newStandardSwitch`, which is similar to the `createSwitch` of the first tutorial, but is bundled directly in libp2p
##
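For comparison, a hedged sketch of roughly what `newStandardSwitch` bundles, spelled out with the `SwitchBuilder` chain from the first tutorial (defaults assumed):

```nim
import libp2p

let rng = newRng()
# approximately equivalent to newStandardSwitch(rng = rng)
let switch = SwitchBuilder
  .new()
  .withRng(rng)
  .withAddress(MultiAddress.init("/ip4/127.0.0.1/tcp/0").tryGet())
  .withTcpTransport()
  .withMplex()
  .withNoise()
  .build()
```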


@@ -57,8 +57,8 @@ proc decode(_: type Metric, buf: seq[byte]): Result[Metric, ProtoError] =
#
# We are just checking the error, and ignoring whether the value
# is present or not (default values are valid).
discard ?pb.getField(1, res.name)
discard ?pb.getField(2, res.value)
discard ? pb.getField(1, res.name)
discard ? pb.getField(2, res.value)
ok(res)
proc encode(m: MetricList): ProtoBuffer =
@@ -72,10 +72,10 @@ proc decode(_: type MetricList, buf: seq[byte]): Result[MetricList, ProtoError]
res: MetricList
metrics: seq[seq[byte]]
let pb = initProtoBuffer(buf)
discard ?pb.getRepeatedField(1, metrics)
discard ? pb.getRepeatedField(1, metrics)
for metric in metrics:
res.metrics &= ?Metric.decode(metric)
res.metrics &= ? Metric.decode(metric)
ok(res)
## ## Results instead of exceptions
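A self-contained, hedged sketch of the `Result`/`?` pattern used by the decoders above (toy procs, not part of the tutorial):

```nim
import strutils
import results # `stew/results` on older setups

proc toInt(s: string): Result[int, string] =
  try: ok(parseInt(s))
  except ValueError: err("not a number: " & s)

proc sum(a, b: string): Result[int, string] =
  # `?` unwraps an ok value or propagates the error to the caller,
  # just like the `? pb.getField(...)` calls above
  ok(? toInt(a) + ? toInt(b))

echo sum("2", "3").expect("valid") # 5
echo sum("2", "x").isErr           # true
```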
@@ -102,13 +102,13 @@ proc decode(_: type MetricList, buf: seq[byte]): Result[MetricList, ProtoError]
## ## Creating the protocol
## We'll next create a protocol, like in the last tutorial, to request these metrics from our host
type
MetricCallback = proc(): Future[MetricList] {.raises: [], gcsafe.}
MetricCallback = proc: Future[MetricList] {.raises: [], gcsafe.}
MetricProto = ref object of LPProtocol
metricGetter: MetricCallback
proc new(_: typedesc[MetricProto], cb: MetricCallback): MetricProto =
var res: MetricProto
proc handle(conn: Connection, proto: string) {.async.} =
proc handle(conn: Connection, proto: string) {.async, gcsafe.} =
let
metrics = await res.metricGetter()
asProtobuf = metrics.encode()
@@ -126,21 +126,21 @@ proc fetch(p: MetricProto, conn: Connection): Future[MetricList] {.async.} =
return MetricList.decode(protobuf).tryGet()
## We can now create our main procedure:
proc main() {.async.} =
proc main() {.async, gcsafe.} =
let rng = newRng()
proc randomMetricGenerator(): Future[MetricList] {.async.} =
proc randomMetricGenerator: Future[MetricList] {.async.} =
let metricCount = rng[].generate(uint32) mod 16
for i in 0 ..< metricCount + 1:
result.metrics.add(
Metric(name: "metric_" & $i, value: float(rng[].generate(uint16)) / 1000.0)
)
result.metrics.add(Metric(
name: "metric_" & $i,
value: float(rng[].generate(uint16)) / 1000.0
))
return result
let
metricProto1 = MetricProto.new(randomMetricGenerator)
metricProto2 = MetricProto.new(randomMetricGenerator)
switch1 = newStandardSwitch(rng = rng)
switch2 = newStandardSwitch(rng = rng)
switch1 = newStandardSwitch(rng=rng)
switch2 = newStandardSwitch(rng=rng)
switch1.mount(metricProto1)
@@ -148,17 +148,14 @@ proc main() {.async.} =
await switch2.start()
let
conn = await switch2.dial(
switch1.peerInfo.peerId, switch1.peerInfo.addrs, metricProto2.codecs
)
conn = await switch2.dial(switch1.peerInfo.peerId, switch1.peerInfo.addrs, metricProto2.codecs)
metrics = await metricProto2.fetch(conn)
await conn.close()
for metric in metrics.metrics:
echo metric.name, " = ", metric.value
await allFutures(switch1.stop(), switch2.stop())
# close connections and shutdown all transports
await allFutures(switch1.stop(), switch2.stop()) # close connections and shutdown all transports
waitFor(main())


@@ -7,7 +7,7 @@
## and allows to balance between latency, bandwidth usage,
## privacy and attack resistance.
##
## You'll find a good explanation of how GossipSub works
## You'll find a good explanation on how GossipSub works
## [here.](https://docs.libp2p.io/concepts/publish-subscribe/) There are a lot
## of parameters you can tweak to adjust how GossipSub behaves but here we'll
## use the sane defaults shipped with libp2p.
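Day to day, GossipSub boils down to mounting it on a switch, then subscribing and publishing on topics. A hedged sketch using the same calls that appear later in this tutorial:

```nim
import chronos, libp2p
import libp2p/protocols/pubsub/gossipsub

proc main() {.async.} =
  let
    switch = newStandardSwitch()
    gossip = GossipSub.init(switch = switch, triggerSelf = true)
  switch.mount(gossip)
  await switch.start()

  gossip.subscribe(
    "metrics",
    proc(topic: string, data: seq[byte]) {.async.} =
      echo "got ", data.len, " bytes on ", topic
  )
  discard await gossip.publish("metrics", @[byte 1, 2, 3])
  await sleepAsync(100.milliseconds) # give triggerSelf time to deliver
  await switch.stop()

waitFor main()
```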
@@ -40,8 +40,8 @@ proc encode(m: Metric): ProtoBuffer =
proc decode(_: type Metric, buf: seq[byte]): Result[Metric, ProtoError] =
var res: Metric
let pb = initProtoBuffer(buf)
discard ?pb.getField(1, res.name)
discard ?pb.getField(2, res.value)
discard ? pb.getField(1, res.name)
discard ? pb.getField(2, res.value)
ok(res)
proc encode(m: MetricList): ProtoBuffer =
@@ -56,11 +56,11 @@ proc decode(_: type MetricList, buf: seq[byte]): Result[MetricList, ProtoError]
res: MetricList
metrics: seq[seq[byte]]
let pb = initProtoBuffer(buf)
discard ?pb.getRepeatedField(1, metrics)
discard ? pb.getRepeatedField(1, metrics)
for metric in metrics:
res.metrics &= ?Metric.decode(metric)
?pb.getRequiredField(2, res.hostname)
res.metrics &= ? Metric.decode(metric)
? pb.getRequiredField(2, res.hostname)
ok(res)
## This is exactly like the previous structure, except that we added
@@ -73,14 +73,11 @@ type Node = tuple[switch: Switch, gossip: GossipSub, hostname: string]
proc oneNode(node: Node, rng: ref HmacDrbgContext) {.async.} =
# This procedure will handle one of the nodes of the network
node.gossip.addValidator(
["metrics"],
node.gossip.addValidator(["metrics"],
proc(topic: string, message: Message): Future[ValidationResult] {.async.} =
let decoded = MetricList.decode(message.data)
if decoded.isErr:
return ValidationResult.Reject
if decoded.isErr: return ValidationResult.Reject
return ValidationResult.Accept
,
)
# This "validator" will attach to the `metrics` topic and make sure
# that every message in this topic is valid. This allows us to stop
@@ -90,24 +87,23 @@ proc oneNode(node: Node, rng: ref HmacDrbgContext) {.async.} =
# `John` will be responsible to log the metrics, the rest of the nodes
# will just forward them in the network
if node.hostname == "John":
node.gossip.subscribe(
"metrics",
proc(topic: string, data: seq[byte]) {.async.} =
node.gossip.subscribe("metrics",
proc (topic: string, data: seq[byte]) {.async.} =
echo MetricList.decode(data).tryGet()
,
)
else:
node.gossip.subscribe("metrics", nil)
# Create random metrics 10 times and broadcast them
for _ in 0 ..< 10:
for _ in 0..<10:
await sleepAsync(500.milliseconds)
var metricList = MetricList(hostname: node.hostname)
let metricCount = rng[].generate(uint32) mod 4
for i in 0 ..< metricCount + 1:
metricList.metrics.add(
Metric(name: "metric_" & $i, value: float(rng[].generate(uint16)) / 1000.0)
)
metricList.metrics.add(Metric(
name: "metric_" & $i,
value: float(rng[].generate(uint16)) / 1000.0
))
discard await node.gossip.publish("metrics", encode(metricList).buffer)
await node.switch.stop()
@@ -115,13 +111,13 @@ proc oneNode(node: Node, rng: ref HmacDrbgContext) {.async.} =
## For our main procedure, we'll create a few nodes, and connect them together.
## Note that they are not all interconnected, but GossipSub will take care of
## broadcasting to the full network nonetheless.
proc main() {.async.} =
proc main {.async.} =
let rng = newRng()
var nodes: seq[Node]
for hostname in ["John", "Walter", "David", "Thuy", "Amy"]:
let
switch = newStandardSwitch(rng = rng)
switch = newStandardSwitch(rng=rng)
gossip = GossipSub.init(switch = switch, triggerSelf = true)
switch.mount(gossip)
await switch.start()
@@ -131,12 +127,11 @@ proc main() {.async.} =
for index, node in nodes:
# Connect to a few neighbors
for otherNodeIdx in index - 1 .. index + 2:
if otherNodeIdx notin 0 ..< nodes.len or otherNodeIdx == index:
continue
if otherNodeIdx notin 0 ..< nodes.len or otherNodeIdx == index: continue
let otherNode = nodes[otherNodeIdx]
await node.switch.connect(
otherNode.switch.peerInfo.peerId, otherNode.switch.peerInfo.addrs
)
otherNode.switch.peerInfo.peerId,
otherNode.switch.peerInfo.addrs)
var allFuts: seq[Future[void]]
for node in nodes:


@@ -2,7 +2,7 @@
##
## In the [previous tutorial](tutorial_4_gossipsub.md), we built a custom protocol using [protobuf](https://developers.google.com/protocol-buffers) and
## spread information (some metrics) on the network using gossipsub.
## For this tutorial, on the other hand, we'll go back to a simple example
## For this tutorial, on the other hand, we'll go back on a simple example
## we'll try to discover specific peers to greet on the network.
##
## First, as usual, we import the dependencies:
@@ -20,24 +20,22 @@ import libp2p/discovery/discoverymngr
##
## Note that other discovery methods such as [Kademlia](https://github.com/libp2p/specs/blob/master/kad-dht/README.md) or [discv5](https://github.com/ethereum/devp2p/blob/master/discv5/discv5.md) exist.
proc createSwitch(rdv: RendezVous = RendezVous.new()): Switch =
SwitchBuilder
.new()
.withRng(newRng())
.withAddresses(@[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()])
.withTcpTransport()
.withYamux()
.withNoise()
.withRendezVous(rdv)
.build()
SwitchBuilder.new()
.withRng(newRng())
.withAddresses(@[ MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() ])
.withTcpTransport()
.withYamux()
.withNoise()
.withRendezVous(rdv)
.build()
# Create a really simple protocol to log one message received then close the stream
const DumbCodec = "/dumb/proto/1.0.0"
type DumbProto = ref object of LPProtocol
proc new(T: typedesc[DumbProto], nodeNumber: int): T =
proc handle(conn: Connection, proto: string) {.async.} =
proc handle(conn: Connection, proto: string) {.async, gcsafe.} =
echo "Node", nodeNumber, " received: ", string.fromBytes(await conn.readLp(1024))
await conn.close()
return T.new(codecs = @[DumbCodec], handler = handle)
## ## Bootnodes
@@ -51,7 +49,7 @@ proc new(T: typedesc[DumbProto], nodeNumber: int): T =
## (rendezvous in this case) as a bootnode. For this example, we'll
## create a bootnode, and then every peer will advertise itself on the
## bootnode, and use it to find other peers
proc main() {.async.} =
proc main() {.async, gcsafe.} =
let bootNode = createSwitch()
await bootNode.start()
@@ -60,7 +58,7 @@ proc main() {.async.} =
switches: seq[Switch] = @[]
discManagers: seq[DiscoveryManager] = @[]
for i in 0 .. 5:
for i in 0..5:
let rdv = RendezVous.new()
# Create a remote future to await at the end of the program
let switch = createSwitch(rdv)
@@ -95,7 +93,7 @@ proc main() {.async.} =
# Use the discovery manager to find peers on the OddClub topic to greet them
let queryOddClub = dm.request(RdvNamespace("OddClub"))
for _ in 0 .. 2:
for _ in 0..2:
let
# getPeer gives you a PeerAttribute containing information about the peer.
res = await queryOddClub.getPeer()
@@ -111,7 +109,7 @@ proc main() {.async.} =
# Maybe it was because he wanted to join the EvenGang
let queryEvenGang = dm.request(RdvNamespace("EvenGang"))
for _ in 0 .. 2:
for _ in 0..2:
let
res = await queryEvenGang.getPeer()
conn = await newcomer.dial(res[PeerId], res.getAll(MultiAddress), DumbCodec)
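Condensed, the advertise/request flow above looks like this. A hedged, self-contained sketch (it reuses the builder chain shown earlier; note that without a bootnode or other connected peers, the query would wait forever):

```nim
import chronos, libp2p
import libp2p/protocols/rendezvous
import libp2p/discovery/[rendezvousinterface, discoverymngr]

proc main() {.async.} =
  let
    rdv = RendezVous.new()
    switch = SwitchBuilder.new()
      .withRng(newRng())
      .withAddresses(@[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()])
      .withTcpTransport()
      .withYamux()
      .withNoise()
      .withRendezVous(rdv)
      .build()
    dm = DiscoveryManager()
  dm.add(RendezVousInterface.new(rdv))
  await switch.start()

  dm.advertise(RdvNamespace("OddClub"))           # make ourselves discoverable
  let query = dm.request(RdvNamespace("OddClub")) # look for peers on that topic
  let res = await query.getPeer()
  echo res[PeerId], " found at ", res.getAll(MultiAddress)

  await switch.stop()

waitFor main()
```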


@@ -51,7 +51,7 @@ proc new(_: type[Game]): Game =
tickTime: -3.0, # 3 seconds of "warm-up" time
localPlayer: Player(x: 4, y: 16, currentDir: 3, nextDir: 3, color: 8),
remotePlayer: Player(x: 27, y: 16, currentDir: 1, nextDir: 1, color: 12),
peerFound: newFuture[Connection](),
peerFound: newFuture[Connection]()
)
for pos in 0 .. result.gameMap.high:
if pos mod mapSize in [0, mapSize - 1] or pos div mapSize in [0, mapSize - 1]:
@@ -81,8 +81,7 @@ proc update(g: Game, dt: float32) =
# This is a hacky way to make this happen
waitFor(sleepAsync(1.milliseconds))
# Don't do anything if we are still waiting for an opponent
if not (g.peerFound.finished()) or isNil(g.tickFinished):
return
if not(g.peerFound.finished()) or isNil(g.tickFinished): return
g.tickTime += dt
# Update the wanted direction, making sure we can't go backward
@@ -99,8 +98,7 @@ proc tick(g: Game, p: Player) =
# Move player and check if he lost
p.x += directions[p.currentDir][1]
p.y += directions[p.currentDir][2]
if g.gameMap[p.y * mapSize + p.x] != 0:
p.lost = true
if g.gameMap[p.y * mapSize + p.x] != 0: p.lost = true
g.gameMap[p.y * mapSize + p.x] = p.color
proc mainLoop(g: Game, peer: Connection) {.async.} =
@@ -125,23 +123,16 @@ proc draw(g: Game) =
for pos, color in g.gameMap:
setColor(color)
boxFill(pos mod 32 * 4, pos div 32 * 4, 4, 4)
let text =
if not (g.peerFound.finished()):
"Matchmaking.."
elif g.tickTime < -1.5:
"Welcome to Etron"
elif g.tickTime < 0.0:
"- " & $(int(abs(g.tickTime) / 0.5) + 1) & " -"
elif g.remotePlayer.lost and g.localPlayer.lost:
"DEUCE"
elif g.localPlayer.lost:
"YOU LOOSE"
elif g.remotePlayer.lost:
"YOU WON"
else:
""
let text = if not(g.peerFound.finished()): "Matchmaking.."
elif g.tickTime < -1.5: "Welcome to Etron"
elif g.tickTime < 0.0: "- " & $(int(abs(g.tickTime) / 0.5) + 1) & " -"
elif g.remotePlayer.lost and g.localPlayer.lost: "DEUCE"
elif g.localPlayer.lost: "YOU LOSE"
elif g.remotePlayer.lost: "YOU WON"
else: ""
printc(text, screenWidth div 2, screenHeight div 2)
## ## Matchmaking
## To find an opponent, we will broadcast our address on a
## GossipSub topic, and wait for someone to connect to us.
@@ -152,9 +143,8 @@ proc draw(g: Game) =
## peer know that we are available, check that he is also available,
## and launch the game.
proc new(T: typedesc[GameProto], g: Game): T =
proc handle(conn: Connection, proto: string) {.async.} =
defer:
await conn.closeWithEof()
proc handle(conn: Connection, proto: string) {.async, gcsafe.} =
defer: await conn.closeWithEof()
if g.peerFound.finished or g.hasCandidate:
await conn.close()
return
@@ -167,7 +157,6 @@ proc new(T: typedesc[GameProto], g: Game): T =
# The handler of a protocol must wait for the stream to
# be finished before returning
await conn.join()
return T.new(codecs = @["/tron/1.0.0"], handler = handle)
proc networking(g: Game) {.async.} =
@@ -175,10 +164,9 @@ proc networking(g: Game) {.async.} =
# the Discovery examples combined
let
rdv = RendezVous.new()
switch = SwitchBuilder
.new()
switch = SwitchBuilder.new()
.withRng(newRng())
.withAddresses(@[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()])
.withAddresses(@[ MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() ])
.withTcpTransport()
.withYamux()
.withNoise()
@@ -186,7 +174,9 @@ proc networking(g: Game) {.async.} =
.build()
dm = DiscoveryManager()
gameProto = GameProto.new(g)
gossip = GossipSub.init(switch = switch, triggerSelf = false)
gossip = GossipSub.init(
switch = switch,
triggerSelf = false)
dm.add(RendezVousInterface.new(rdv))
switch.mount(gossip)
@@ -194,11 +184,10 @@ proc networking(g: Game) {.async.} =
gossip.subscribe(
"/tron/matchmaking",
proc(topic: string, data: seq[byte]) {.async.} =
proc (topic: string, data: seq[byte]) {.async.} =
# If we are still looking for an opponent,
# try to match anyone broadcasting its address
if g.peerFound.finished or g.hasCandidate:
return
# try to match anyone broadcasting it's address
if g.peerFound.finished or g.hasCandidate: return
g.hasCandidate = true
try:
@@ -215,12 +204,10 @@ proc networking(g: Game) {.async.} =
swap(g.localPlayer, g.remotePlayer)
except CatchableError as exc:
discard
,
)
await switch.start()
defer:
await switch.stop()
defer: await switch.stop()
# As explained in the last tutorial, we need a bootnode to be able
# to find peers. We could use any libp2p running rendezvous (or any
@@ -256,8 +243,7 @@ proc networking(g: Game) {.async.} =
# We now wait for someone to connect to us (or for us to connect to someone)
let peerConn = await g.peerFound
defer:
await peerConn.closeWithEof()
defer: await peerConn.closeWithEof()
await g.mainLoop(peerConn)
@@ -266,17 +252,7 @@ let
netFut = networking(game)
nico.init("Status", "Tron")
nico.createWindow("Tron", mapSize * 4, mapSize * 4, 4, false)
nico.run(
proc() =
discard
,
proc(dt: float32) =
game.update(dt)
,
proc() =
game.draw()
,
)
nico.run(proc = discard, proc(dt: float32) = game.update(dt), proc = game.draw())
waitFor(netFut.cancelAndWait())
## And that's it! If you want to run this code locally, the simplest way is to use the


@@ -21,8 +21,7 @@ when defined(nimdoc):
## that can help you get started.
# Import stuff for doc
import
libp2p/[
import libp2p/[
protobuf/minprotobuf,
switch,
stream/lpstream,
@@ -34,39 +33,40 @@ when defined(nimdoc):
peerid,
peerinfo,
peerstore,
multiaddress,
]
multiaddress]
proc dummyPrivateProc*() =
## A private proc example
discard
else:
import
libp2p/[
protobuf/minprotobuf,
muxers/muxer,
muxers/mplex/mplex,
stream/lpstream,
stream/bufferstream,
stream/connection,
transports/transport,
transports/tcptransport,
protocols/secure/noise,
cid,
multihash,
multicodec,
errors,
switch,
peerid,
peerinfo,
multiaddress,
builders,
crypto/crypto,
protocols/pubsub,
]
libp2p/[protobuf/minprotobuf,
muxers/muxer,
muxers/mplex/mplex,
stream/lpstream,
stream/bufferstream,
stream/connection,
transports/transport,
transports/tcptransport,
transports/wstransport,
protocols/secure/noise,
protocols/ping,
cid,
multihash,
multibase,
multicodec,
errors,
switch,
peerid,
peerinfo,
multiaddress,
builders,
crypto/crypto,
protocols/pubsub]
export
minprotobuf, switch, peerid, peerinfo, connection, multiaddress, crypto, lpstream,
bufferstream, muxer, mplex, transport, tcptransport, noise, errors, cid, multihash,
minprotobuf, switch, peerid, peerinfo,
connection, multiaddress, crypto, lpstream,
bufferstream, muxer, mplex, transport,
tcptransport, noise, errors, cid, multihash,
multicodec, builders, pubsub


@@ -1,33 +1,30 @@
mode = ScriptMode.Verbose
packageName = "libp2p"
version = "1.4.0"
author = "Status Research & Development GmbH"
description = "LibP2P implementation"
license = "MIT"
skipDirs = @["tests", "examples", "Nim", "tools", "scripts", "docs"]
packageName = "libp2p"
version = "1.0.0"
author = "Status Research & Development GmbH"
description = "LibP2P implementation"
license = "MIT"
skipDirs = @["tests", "examples", "Nim", "tools", "scripts", "docs"]
requires "nim >= 1.6.0",
"nimcrypto >= 0.4.1", "dnsclient >= 0.3.0 & < 0.4.0", "bearssl >= 0.1.4",
"chronicles >= 0.10.2", "chronos >= 4.0.0", "metrics", "secp256k1", "stew#head",
"websock", "unittest2"
requires "nim >= 1.2.0",
"nimcrypto >= 0.4.1",
"dnsclient >= 0.3.0 & < 0.4.0",
"bearssl >= 0.1.4",
"chronicles >= 0.10.2",
"chronos >= 3.0.6",
"metrics",
"secp256k1",
"stew#head",
"websock",
"unittest2 >= 0.0.5 & < 0.1.0"
let nimc = getEnv("NIMC", "nim") # Which nim compiler to use
let lang = getEnv("NIMLANG", "c") # Which backend (c/cpp/js)
let flags = getEnv("NIMFLAGS", "") # Extra flags for the compiler
let verbose = getEnv("V", "") notin ["", "0"]
let cfg =
" --styleCheck:usages --styleCheck:error" &
(if verbose: "" else: " --verbosity:0 --hints:off") &
" --skipParentCfg --skipUserCfg -f" & " --threads:on --opt:speed"
import hashes, strutils
proc runTest(
filename: string, verify: bool = true, sign: bool = true, moreoptions: string = ""
) =
var excstr = nimc & " " & lang & " -d:debug " & cfg & " " & flags
import hashes
proc runTest(filename: string, verify: bool = true, sign: bool = true,
moreoptions: string = "") =
var excstr = "nim c --skipParentCfg --opt:speed -d:debug "
excstr.add(" " & getEnv("NIMFLAGS") & " ")
excstr.add(" --verbosity:0 --hints:off ")
excstr.add(" -d:libp2p_pubsub_sign=" & $sign)
excstr.add(" -d:libp2p_pubsub_verify=" & $verify)
excstr.add(" " & moreoptions & " ")
@@ -37,7 +34,7 @@ proc runTest(
rmFile "tests/" & filename.toExe
proc buildSample(filename: string, run = false, extraFlags = "") =
var excstr = nimc & " " & lang & " " & cfg & " " & flags & " -p:. " & extraFlags
var excstr = "nim c --opt:speed --threads:on -d:debug --verbosity:0 --hints:off -p:. " & extraFlags
excstr.add(" examples/" & filename)
exec excstr
if run:
@@ -45,8 +42,7 @@ proc buildSample(filename: string, run = false, extraFlags = "") =
rmFile "examples/" & filename.toExe
proc tutorialToMd(filename: string) =
let markdown = gorge "cat " & filename & " | " & nimc & " " & lang &
" -r --verbosity:0 --hints:off tools/markdown_builder.nim "
let markdown = gorge "cat " & filename & " | nim c -r --verbosity:0 --hints:off tools/markdown_builder.nim "
writeFile(filename.replace(".nim", ".md"), markdown)
task testnative, "Runs libp2p native tests":
@@ -59,37 +55,24 @@ task testinterop, "Runs interop tests":
runTest("testinterop")
task testpubsub, "Runs pubsub tests":
runTest(
"pubsub/testgossipinternal",
sign = false,
verify = false,
moreoptions = "-d:pubsub_internal_testing",
)
runTest("pubsub/testgossipinternal", sign = false, verify = false, moreoptions = "-d:pubsub_internal_testing")
runTest("pubsub/testpubsub")
runTest("pubsub/testpubsub", sign = false, verify = false)
runTest(
"pubsub/testpubsub",
sign = false,
verify = false,
moreoptions = "-d:libp2p_pubsub_anonymize=true",
)
runTest("pubsub/testpubsub", sign = false, verify = false, moreoptions = "-d:libp2p_pubsub_anonymize=true")
task testpubsub_slim, "Runs pubsub tests":
runTest(
"pubsub/testgossipinternal",
sign = false,
verify = false,
moreoptions = "-d:pubsub_internal_testing",
)
runTest("pubsub/testgossipinternal", sign = false, verify = false, moreoptions = "-d:pubsub_internal_testing")
runTest("pubsub/testpubsub")
task testfilter, "Run PKI filter test":
runTest("testpkifilter", moreoptions = "-d:libp2p_pki_schemes=\"secp256k1\"")
runTest("testpkifilter", moreoptions = "-d:libp2p_pki_schemes=\"secp256k1;ed25519\"")
runTest(
"testpkifilter", moreoptions = "-d:libp2p_pki_schemes=\"secp256k1;ed25519;ecnist\""
)
runTest("testpkifilter", moreoptions = "-d:libp2p_pki_schemes=")
runTest("testpkifilter",
moreoptions = "-d:libp2p_pki_schemes=\"secp256k1\"")
runTest("testpkifilter",
moreoptions = "-d:libp2p_pki_schemes=\"secp256k1;ed25519\"")
runTest("testpkifilter",
moreoptions = "-d:libp2p_pki_schemes=\"secp256k1;ed25519;ecnist\"")
runTest("testpkifilter",
moreoptions = "-d:libp2p_pki_schemes=")
task test, "Runs the test suite":
exec "nimble testnative"
@@ -121,13 +104,15 @@ task examples_build, "Build the samples":
buildSample("circuitrelay", true)
buildSample("tutorial_1_connect", true)
buildSample("tutorial_2_customproto", true)
buildSample("tutorial_3_protobuf", true)
buildSample("tutorial_4_gossipsub", true)
buildSample("tutorial_5_discovery", true)
exec "nimble install -y nimpng@#HEAD"
# this is to fix broken build on 1.7.3, remove it when nimpng version 0.3.2 or later is released
exec "nimble install -y nico"
buildSample("tutorial_6_game", false, "--styleCheck:off")
if (NimMajor, NimMinor) > (1, 2):
# These tutorials rely on post-1.4 exception tracking
buildSample("tutorial_3_protobuf", true)
buildSample("tutorial_4_gossipsub", true)
buildSample("tutorial_5_discovery", true)
# Nico doesn't work in 1.2
exec "nimble install -y nimpng@#HEAD" # this is to fix broken build on 1.7.3, remove it when nimpng version 0.3.2 or later is released
exec "nimble install -y nico"
buildSample("tutorial_6_game", false, "--styleCheck:off")
# pin system
# while nimble lockfile
@@ -138,14 +123,12 @@ task pin, "Create a lockfile":
# pinner.nim was originally here
# but you can't read output from
# a command in a nimscript
exec nimc & " c -r tools/pinner.nim"
exec "nim c -r tools/pinner.nim"
import sequtils
import os
task install_pinned, "Reads the lockfile":
let toInstall = readFile(PinFile).splitWhitespace().mapIt(
(it.split(";", 1)[0], it.split(";", 1)[1])
)
let toInstall = readFile(PinFile).splitWhitespace().mapIt((it.split(";", 1)[0], it.split(";", 1)[1]))
# [('packageName', 'packageFullUri')]
rmDir("nimbledeps")
@@ -155,7 +138,8 @@ task install_pinned, "Reads the lockfile":
# Remove the automatically installed deps
# (inefficient you say?)
let nimblePkgs =
if system.dirExists("nimbledeps/pkgs"): "nimbledeps/pkgs" else: "nimbledeps/pkgs2"
if system.dirExists("nimbledeps/pkgs"): "nimbledeps/pkgs"
else: "nimbledeps/pkgs2"
for dependency in listDirs(nimblePkgs):
let
fileName = dependency.extractFilename
@@ -163,12 +147,14 @@ task install_pinned, "Reads the lockfile":
packageName = fileName.split('-')[0]
if toInstall.anyIt(
it[0] == packageName and (
it[1].split('#')[^1] in fileContent or # nimble for nim 2.X
fileName.endsWith(it[1].split('#')[^1]) # nimble for nim 1.X
)
) == false or fileName.split('-')[^1].len < 20: # safeguard for nimble for nim 1.X
rmDir(dependency)
it[0] == packageName and
(
it[1].split('#')[^1] in fileContent or # nimble for nim 2.X
fileName.endsWith(it[1].split('#')[^1]) # nimble for nim 1.X
)
) == false or
fileName.split('-')[^1].len < 20: # safeguard for nimble for nim 1.X
rmDir(dependency)
task unpin, "Restore global package use":
rmDir("nimbledeps")
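(As an aside on the lockfile format consumed by install_pinned above: each whitespace-separated PinFile entry has the form name;url#commit. A minimal parsing sketch, with a hypothetical entry:

import sequtils, strutils

let entry = "chronos;https://github.com/status-im/nim-chronos#abc123" # hypothetical
let parsed = entry.splitWhitespace().mapIt((it.split(";", 1)[0], it.split(";", 1)[1]))
doAssert parsed == @[("chronos", "https://github.com/status-im/nim-chronos#abc123")]
)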

View File

@@ -9,39 +9,38 @@
## This module contains a Switch Building helper.
runnableExamples:
let switch = SwitchBuilder.new().withRng(rng).withAddresses(multiaddress)
# etc
.build()
let switch =
SwitchBuilder.new()
.withRng(rng)
.withAddresses(multiaddress)
# etc
.build()
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import options, tables, chronos, chronicles, sequtils
import
switch,
peerid,
peerinfo,
stream/connection,
multiaddress,
crypto/crypto,
transports/[transport, tcptransport],
options, tables, chronos, chronicles, sequtils,
switch, peerid, peerinfo, stream/connection, multiaddress,
crypto/crypto, transports/[transport, tcptransport],
muxers/[muxer, mplex/mplex, yamux/yamux],
protocols/[identify, secure/secure, secure/noise, rendezvous],
protocols/connectivity/[autonat/server, relay/relay, relay/client, relay/rtransport],
connmanager,
upgrademngrs/muxedupgrade,
observedaddrmanager,
connmanager, upgrademngrs/muxedupgrade,
nameresolving/nameresolver,
errors,
utility
import services/wildcardresolverservice
errors, utility
export switch, peerid, peerinfo, connection, multiaddress, crypto, errors
export
switch, peerid, peerinfo, connection, multiaddress, crypto, errors
type
TransportProvider* {.public.} = proc(upgr: Upgrade): Transport {.gcsafe, raises: [].}
TransportProvider* {.public.} = proc(upgr: Upgrade): Transport {.gcsafe, raises: [Defect].}
SecureProtocol* {.pure.} = enum
Noise
Noise,
Secio {.deprecated.}
SwitchBuilder* = ref object
privKey: Option[PrivateKey]
@@ -58,19 +57,18 @@ type
protoVersion: string
agentVersion: string
nameResolver: NameResolver
peerStoreCapacity: Opt[int]
peerStoreCapacity: Option[int]
autonat: bool
circuitRelay: Relay
rdv: RendezVous
services: seq[Service]
observedAddrManager: ObservedAddrManager
enableWildcardResolver: bool
proc new*(T: type[SwitchBuilder]): T {.public.} =
## Creates a SwitchBuilder
let address =
MultiAddress.init("/ip4/127.0.0.1/tcp/0").expect("Should initialize to default")
let address = MultiAddress
.init("/ip4/127.0.0.1/tcp/0")
.expect("Should initialize to default")
SwitchBuilder(
privKey: none(PrivateKey),
@@ -81,59 +79,53 @@ proc new*(T: type[SwitchBuilder]): T {.public.} =
maxOut: -1,
maxConnsPerPeer: MaxConnectionsPerPeer,
protoVersion: ProtoVersion,
agentVersion: AgentVersion,
enableWildcardResolver: true,
)
agentVersion: AgentVersion)
proc withPrivateKey*(
b: SwitchBuilder, privateKey: PrivateKey
): SwitchBuilder {.public.} =
proc withPrivateKey*(b: SwitchBuilder, privateKey: PrivateKey): SwitchBuilder {.public.} =
## Set the private key of the switch. Will be used to
## generate a PeerId
b.privKey = some(privateKey)
b
proc withAddresses*(
b: SwitchBuilder, addresses: seq[MultiAddress], enableWildcardResolver: bool = true
): SwitchBuilder {.public.} =
## | Set the listening addresses of the switch
## | Calling it multiple times will override the value
b.addresses = addresses
b.enableWildcardResolver = enableWildcardResolver
b
proc withAddress*(
b: SwitchBuilder, address: MultiAddress, enableWildcardResolver: bool = true
): SwitchBuilder {.public.} =
proc withAddress*(b: SwitchBuilder, address: MultiAddress): SwitchBuilder {.public.} =
## | Set the listening address of the switch
## | Calling it multiple times will override the value
b.withAddresses(@[address], enableWildcardResolver)
b.addresses = @[address]
b
proc withAddresses*(b: SwitchBuilder, addresses: seq[MultiAddress]): SwitchBuilder {.public.} =
## | Set the listening addresses of the switch
## | Calling it multiple times will override the value
b.addresses = addresses
b
proc withSignedPeerRecord*(b: SwitchBuilder, sendIt = true): SwitchBuilder {.public.} =
b.sendSignedPeerRecord = sendIt
b
proc withMplex*(
b: SwitchBuilder, inTimeout = 5.minutes, outTimeout = 5.minutes, maxChannCount = 200
): SwitchBuilder {.public.} =
b: SwitchBuilder,
inTimeout = 5.minutes,
outTimeout = 5.minutes,
maxChannCount = 200): SwitchBuilder {.public.} =
## | Uses `Mplex <https://docs.libp2p.io/concepts/stream-multiplexing/#mplex>`_ as a multiplexer
## | `Timeout` is the duration after which an inactive connection will be closed
proc newMuxer(conn: Connection): Muxer =
Mplex.new(conn, inTimeout, outTimeout, maxChannCount)
Mplex.new(
conn,
inTimeout,
outTimeout,
maxChannCount)
assert b.muxers.countIt(it.codec == MplexCodec) == 0, "Mplex built multiple times"
b.muxers.add(MuxerProvider.new(newMuxer, MplexCodec))
b
proc withYamux*(
b: SwitchBuilder,
windowSize: int = YamuxDefaultWindowSize,
inTimeout: Duration = 5.minutes,
outTimeout: Duration = 5.minutes,
): SwitchBuilder =
proc newMuxer(conn: Connection): Muxer =
Yamux.new(conn, windowSize, inTimeout = inTimeout, outTimeout = outTimeout)
proc withYamux*(b: SwitchBuilder): SwitchBuilder =
proc newMuxer(conn: Connection): Muxer = Yamux.new(conn)
assert b.muxers.countIt(it.codec == YamuxCodec) == 0, "Yamux built multiple times"
b.muxers.add(MuxerProvider.new(newMuxer, YamuxCodec))
@@ -143,36 +135,24 @@ proc withNoise*(b: SwitchBuilder): SwitchBuilder {.public.} =
b.secureManagers.add(SecureProtocol.Noise)
b
proc withTransport*(
b: SwitchBuilder, prov: TransportProvider
): SwitchBuilder {.public.} =
proc withTransport*(b: SwitchBuilder, prov: TransportProvider): SwitchBuilder {.public.} =
## Use a custom transport
runnableExamples:
let switch = SwitchBuilder
.new()
.withTransport(
proc(upgr: Upgrade): Transport =
TcpTransport.new(flags, upgr)
)
let switch =
SwitchBuilder.new()
.withTransport(proc(upgr: Upgrade): Transport = TcpTransport.new(flags, upgr))
.build()
b.transports.add(prov)
b
proc withTcpTransport*(
b: SwitchBuilder, flags: set[ServerFlags] = {}
): SwitchBuilder {.public.} =
b.withTransport(
proc(upgr: Upgrade): Transport =
TcpTransport.new(flags, upgr)
)
proc withTcpTransport*(b: SwitchBuilder, flags: set[ServerFlags] = {}): SwitchBuilder {.public.} =
b.withTransport(proc(upgr: Upgrade): Transport = TcpTransport.new(flags, upgr))
proc withRng*(b: SwitchBuilder, rng: ref HmacDrbgContext): SwitchBuilder {.public.} =
b.rng = rng
b
proc withMaxConnections*(
b: SwitchBuilder, maxConnections: int
): SwitchBuilder {.public.} =
proc withMaxConnections*(b: SwitchBuilder, maxConnections: int): SwitchBuilder {.public.} =
## Maximum concurrent connections of the switch. You should either use this, or
## `withMaxIn <#withMaxIn,SwitchBuilder,int>`_ & `withMaxOut<#withMaxOut,SwitchBuilder,int>`_
b.maxConnections = maxConnections
@@ -188,31 +168,23 @@ proc withMaxOut*(b: SwitchBuilder, maxOut: int): SwitchBuilder {.public.} =
b.maxOut = maxOut
b
proc withMaxConnsPerPeer*(
b: SwitchBuilder, maxConnsPerPeer: int
): SwitchBuilder {.public.} =
proc withMaxConnsPerPeer*(b: SwitchBuilder, maxConnsPerPeer: int): SwitchBuilder {.public.} =
b.maxConnsPerPeer = maxConnsPerPeer
b
proc withPeerStore*(b: SwitchBuilder, capacity: int): SwitchBuilder {.public.} =
b.peerStoreCapacity = Opt.some(capacity)
b.peerStoreCapacity = some(capacity)
b
proc withProtoVersion*(
b: SwitchBuilder, protoVersion: string
): SwitchBuilder {.public.} =
proc withProtoVersion*(b: SwitchBuilder, protoVersion: string): SwitchBuilder {.public.} =
b.protoVersion = protoVersion
b
proc withAgentVersion*(
b: SwitchBuilder, agentVersion: string
): SwitchBuilder {.public.} =
proc withAgentVersion*(b: SwitchBuilder, agentVersion: string): SwitchBuilder {.public.} =
b.agentVersion = agentVersion
b
proc withNameResolver*(
b: SwitchBuilder, nameResolver: NameResolver
): SwitchBuilder {.public.} =
proc withNameResolver*(b: SwitchBuilder, nameResolver: NameResolver): SwitchBuilder {.public.} =
b.nameResolver = nameResolver
b
@@ -224,9 +196,7 @@ proc withCircuitRelay*(b: SwitchBuilder, r: Relay = Relay.new()): SwitchBuilder
b.circuitRelay = r
b
proc withRendezVous*(
b: SwitchBuilder, rdv: RendezVous = RendezVous.new()
): SwitchBuilder =
proc withRendezVous*(b: SwitchBuilder, rdv: RendezVous = RendezVous.new()): SwitchBuilder =
b.rdv = rdv
b
@@ -234,44 +204,40 @@ proc withServices*(b: SwitchBuilder, services: seq[Service]): SwitchBuilder =
b.services = services
b
proc withObservedAddrManager*(
b: SwitchBuilder, observedAddrManager: ObservedAddrManager
): SwitchBuilder =
b.observedAddrManager = observedAddrManager
b
proc build*(b: SwitchBuilder): Switch
{.raises: [Defect, LPError], public.} =
proc build*(b: SwitchBuilder): Switch {.raises: [LPError], public.} =
if b.rng == nil: # newRng could fail
raise newException(Defect, "Cannot initialize RNG")
let pkRes = PrivateKey.random(b.rng[])
let seckey = b.privKey.get(otherwise = pkRes.expect("Expected default Private Key"))
let
seckey = b.privKey.get(otherwise = pkRes.expect("Expected default Private Key"))
var secureManagerInstances: seq[Secure]
var
secureManagerInstances: seq[Secure]
if SecureProtocol.Noise in b.secureManagers:
secureManagerInstances.add(Noise.new(b.rng, seckey).Secure)
let peerInfo = PeerInfo.new(
seckey, b.addresses, protoVersion = b.protoVersion, agentVersion = b.agentVersion
)
let identify =
if b.observedAddrManager != nil:
Identify.new(peerInfo, b.sendSignedPeerRecord, b.observedAddrManager)
else:
Identify.new(peerInfo, b.sendSignedPeerRecord)
let
peerInfo = PeerInfo.new(
seckey,
b.addresses,
protoVersion = b.protoVersion,
agentVersion = b.agentVersion)
let
connManager =
ConnManager.new(b.maxConnsPerPeer, b.maxConnections, b.maxIn, b.maxOut)
identify = Identify.new(peerInfo, b.sendSignedPeerRecord)
connManager = ConnManager.new(b.maxConnsPerPeer, b.maxConnections, b.maxIn, b.maxOut)
ms = MultistreamSelect.new()
muxedUpgrade = MuxedUpgrade.new(b.muxers, secureManagerInstances, ms)
muxedUpgrade = MuxedUpgrade.new(identify, b.muxers, secureManagerInstances, connManager, ms)
let transports = block:
var transports: seq[Transport]
for tProvider in b.transports:
transports.add(tProvider(muxedUpgrade))
transports
let
transports = block:
var transports: seq[Transport]
for tProvider in b.transports:
transports.add(tProvider(muxedUpgrade))
transports
if b.secureManagers.len == 0:
b.secureManagers &= SecureProtocol.Noise
@@ -279,27 +245,22 @@ proc build*(b: SwitchBuilder): Switch {.raises: [LPError], public.} =
if isNil(b.rng):
b.rng = newRng()
let peerStore = block:
b.peerStoreCapacity.withValue(capacity):
PeerStore.new(identify, capacity)
let peerStore =
if isSome(b.peerStoreCapacity):
PeerStore.new(b.peerStoreCapacity.get())
else:
PeerStore.new(identify)
if b.enableWildcardResolver:
b.services.insert(WildcardAddressResolverService.new(), 0)
PeerStore.new()
let switch = newSwitch(
peerInfo = peerInfo,
transports = transports,
identity = identify,
secureManagers = secureManagerInstances,
connManager = connManager,
ms = ms,
nameResolver = b.nameResolver,
peerStore = peerStore,
services = b.services,
)
switch.mount(identify)
services = b.services)
if b.autonat:
let autonat = Autonat.new(switch)
@@ -318,28 +279,29 @@ proc build*(b: SwitchBuilder): Switch {.raises: [LPError], public.} =
return switch
proc newStandardSwitch*(
privKey = none(PrivateKey),
addrs: MultiAddress | seq[MultiAddress] =
MultiAddress.init("/ip4/127.0.0.1/tcp/0").expect("valid address"),
secureManagers: openArray[SecureProtocol] = [SecureProtocol.Noise],
transportFlags: set[ServerFlags] = {},
rng = newRng(),
inTimeout: Duration = 5.minutes,
outTimeout: Duration = 5.minutes,
maxConnections = MaxConnections,
maxIn = -1,
maxOut = -1,
maxConnsPerPeer = MaxConnectionsPerPeer,
nameResolver: NameResolver = nil,
sendSignedPeerRecord = false,
peerStoreCapacity = 1000,
): Switch {.raises: [LPError], public.} =
privKey = none(PrivateKey),
addrs: MultiAddress | seq[MultiAddress] = MultiAddress.init("/ip4/127.0.0.1/tcp/0").tryGet(),
secureManagers: openArray[SecureProtocol] = [
SecureProtocol.Noise,
],
transportFlags: set[ServerFlags] = {},
rng = newRng(),
inTimeout: Duration = 5.minutes,
outTimeout: Duration = 5.minutes,
maxConnections = MaxConnections,
maxIn = -1,
maxOut = -1,
maxConnsPerPeer = MaxConnectionsPerPeer,
nameResolver: NameResolver = nil,
sendSignedPeerRecord = false,
peerStoreCapacity = 1000): Switch
{.raises: [Defect, LPError], public.} =
## Helper for common switch configurations.
let addrs =
when addrs is MultiAddress:
@[addrs]
else:
addrs
if SecureProtocol.Secio in secureManagers:
quit("Secio is deprecated!") # use of secio is unsafe
let addrs = when addrs is MultiAddress: @[addrs] else: addrs
var b = SwitchBuilder
.new()
.withAddresses(addrs)
@@ -349,13 +311,13 @@ proc newStandardSwitch*(
.withMaxIn(maxIn)
.withMaxOut(maxOut)
.withMaxConnsPerPeer(maxConnsPerPeer)
.withPeerStore(capacity = peerStoreCapacity)
.withPeerStore(capacity=peerStoreCapacity)
.withMplex(inTimeout, outTimeout)
.withTcpTransport(transportFlags)
.withNameResolver(nameResolver)
.withNoise()
privKey.withValue(pkey):
b = b.withPrivateKey(pkey)
if privKey.isSome():
b = b.withPrivateKey(privKey.get())
b.build()
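(Tying the builder procs together, a minimal usage sketch based on the calls shown above; the address and option choices are illustrative, not prescriptive:

let rng = newRng()
let switch = SwitchBuilder
  .new()
  .withRng(rng)
  .withAddress(MultiAddress.init("/ip4/127.0.0.1/tcp/0").expect("valid address"))
  .withTcpTransport()
  .withMplex()
  .withNoise()
  .build()
)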

View File

@@ -1,5 +1,5 @@
# Nim-LibP2P
# Copyright (c) 2023-2024 Status Research & Development GmbH
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
@@ -9,7 +9,10 @@
## This module implements CID (Content IDentifier).
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import tables, hashes
import multibase, multicodec, multihash, vbuffer, varint
@@ -19,16 +22,10 @@ export results
type
CidError* {.pure.} = enum
Error
Incorrect
Unsupported
Overrun
Error, Incorrect, Unsupported, Overrun
CidVersion* = enum
CIDvIncorrect
CIDv0
CIDv1
CIDvReserved
CIDvIncorrect, CIDv0, CIDv1, CIDvReserved
Cid* = object
cidver*: CidVersion
@@ -36,51 +33,54 @@ type
hpos*: int
data*: VBuffer
const ContentIdsList = [
multiCodec("raw"),
multiCodec("dag-pb"),
multiCodec("dag-cbor"),
multiCodec("dag-json"),
multiCodec("git-raw"),
multiCodec("eth-block"),
multiCodec("eth-block-list"),
multiCodec("eth-tx-trie"),
multiCodec("eth-tx"),
multiCodec("eth-tx-receipt-trie"),
multiCodec("eth-tx-receipt"),
multiCodec("eth-state-trie"),
multiCodec("eth-account-snapshot"),
multiCodec("eth-storage-trie"),
multiCodec("bitcoin-block"),
multiCodec("bitcoin-tx"),
multiCodec("zcash-block"),
multiCodec("zcash-tx"),
multiCodec("stellar-block"),
multiCodec("stellar-tx"),
multiCodec("decred-block"),
multiCodec("decred-tx"),
multiCodec("dash-block"),
multiCodec("dash-tx"),
multiCodec("torrent-info"),
multiCodec("torrent-file"),
multiCodec("ed25519-pub"),
]
const
ContentIdsList = [
multiCodec("raw"),
multiCodec("dag-pb"),
multiCodec("dag-cbor"),
multiCodec("dag-json"),
multiCodec("git-raw"),
multiCodec("eth-block"),
multiCodec("eth-block-list"),
multiCodec("eth-tx-trie"),
multiCodec("eth-tx"),
multiCodec("eth-tx-receipt-trie"),
multiCodec("eth-tx-receipt"),
multiCodec("eth-state-trie"),
multiCodec("eth-account-snapshot"),
multiCodec("eth-storage-trie"),
multiCodec("bitcoin-block"),
multiCodec("bitcoin-tx"),
multiCodec("zcash-block"),
multiCodec("zcash-tx"),
multiCodec("stellar-block"),
multiCodec("stellar-tx"),
multiCodec("decred-block"),
multiCodec("decred-tx"),
multiCodec("dash-block"),
multiCodec("dash-tx"),
multiCodec("torrent-info"),
multiCodec("torrent-file"),
multiCodec("ed25519-pub")
]
proc initCidCodeTable(): Table[int, MultiCodec] {.compileTime.} =
for item in ContentIdsList:
result[int(item)] = item
const CodeContentIds = initCidCodeTable()
const
CodeContentIds = initCidCodeTable()
template orError*(exp: untyped, err: untyped): untyped =
exp.mapErr do(_: auto) -> auto:
err
(exp.mapErr do (_: auto) -> auto: err)
proc decode(data: openArray[byte]): Result[Cid, CidError] =
if len(data) == 34 and data[0] == 0x12'u8 and data[1] == 0x20'u8:
ok(
Cid(cidver: CIDv0, mcodec: multiCodec("dag-pb"), hpos: 0, data: initVBuffer(data))
)
ok(Cid(
cidver: CIDv0,
mcodec: multiCodec("dag-pb"),
hpos: 0,
data: initVBuffer(data)))
else:
var version, codec: uint64
var res, offset: int
@@ -101,18 +101,21 @@ proc decode(data: openArray[byte]): Result[Cid, CidError] =
err(CidError.Incorrect)
else:
offset += res
var mcodec =
CodeContentIds.getOrDefault(cast[int](codec), InvalidMultiCodec)
var mcodec = CodeContentIds.getOrDefault(cast[int](codec),
InvalidMultiCodec)
if mcodec == InvalidMultiCodec:
err(CidError.Incorrect)
else:
if not MultiHash.validate(
vb.buffer.toOpenArray(vb.offset, vb.buffer.high)
):
if not MultiHash.validate(vb.buffer.toOpenArray(vb.offset,
vb.buffer.high)):
err(CidError.Incorrect)
else:
vb.finish()
ok(Cid(cidver: CIDv1, mcodec: mcodec, hpos: offset, data: vb))
ok(Cid(
cidver: CIDv1,
mcodec: mcodec,
hpos: offset,
data: vb))
proc decode(data: openArray[char]): Result[Cid, CidError] =
var buffer: seq[byte]
@@ -172,9 +175,7 @@ proc mhash*(cid: Cid): Result[MultiHash, CidError] =
if cid.cidver notin {CIDv0, CIDv1}:
err(CidError.Incorrect)
else:
MultiHash.init(cid.data.buffer.toOpenArray(cid.hpos, cid.data.high)).orError(
CidError.Incorrect
)
MultiHash.init(cid.data.buffer.toOpenArray(cid.hpos, cid.data.high)).orError(CidError.Incorrect)
proc contentType*(cid: Cid): Result[MultiCodec, CidError] =
## Returns content type part of CID
@@ -187,15 +188,12 @@ proc version*(cid: Cid): CidVersion =
## Returns CID version
result = cid.cidver
proc init*[T: char | byte](
ctype: typedesc[Cid], data: openArray[T]
): Result[Cid, CidError] =
proc init*[T: char|byte](ctype: typedesc[Cid], data: openArray[T]): Result[Cid, CidError] =
## Create new content identifier using array of bytes or string ``data``.
decode(data)
proc init*(
ctype: typedesc[Cid], version: CidVersion, content: MultiCodec, hash: MultiHash
): Result[Cid, CidError] =
proc init*(ctype: typedesc[Cid], version: CidVersion, content: MultiCodec,
hash: MultiHash): Result[Cid, CidError] =
## Create new content identifier using content type ``content`` and
## MultiHash ``hash`` using version ``version``.
##
@@ -218,7 +216,8 @@ proc init*(
res.data.finish()
return ok(res)
elif version == CIDv1:
let mcodec = CodeContentIds.getOrDefault(cast[int](content), InvalidMultiCodec)
let mcodec = CodeContentIds.getOrDefault(cast[int](content),
InvalidMultiCodec)
if mcodec == InvalidMultiCodec:
return err(CidError.Incorrect)
res.mcodec = mcodec
@@ -237,9 +236,11 @@ proc `==`*(a: Cid, b: Cid): bool =
## are equal, ``false`` otherwise.
if a.mcodec == b.mcodec:
var ah, bh: MultiHash
if MultiHash.decode(a.data.buffer.toOpenArray(a.hpos, a.data.high), ah).isErr:
if MultiHash.decode(
a.data.buffer.toOpenArray(a.hpos, a.data.high), ah).isErr:
return false
if MultiHash.decode(b.data.buffer.toOpenArray(b.hpos, b.data.high), bh).isErr:
if MultiHash.decode(
b.data.buffer.toOpenArray(b.hpos, b.data.high), bh).isErr:
return false
result = (ah == bh)
@@ -263,6 +264,12 @@ proc write*(vb: var VBuffer, cid: Cid) {.inline.} =
## Write CID value ``cid`` to buffer ``vb``.
vb.writeArray(cid.data.buffer)
proc encode*(mbtype: typedesc[MultiBase], encoding: string,
cid: Cid): string {.inline.} =
## Get MultiBase encoded representation of ``cid`` using encoding
## ``encoding``.
result = MultiBase.encode(encoding, cid.data.buffer).tryGet()
proc hash*(cid: Cid): Hash {.inline.} =
hash(cid.data.buffer)
@@ -272,6 +279,9 @@ proc `$`*(cid: Cid): string =
BTCBase58.encode(cid.data.buffer)
elif cid.cidver == CIDv1:
let res = MultiBase.encode("base58btc", cid.data.buffer)
res.get("")
if res.isOk():
res.get()
else:
""
else:
""

View File

@@ -7,11 +7,19 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import std/[tables, sequtils, sets]
import std/[options, tables, sequtils, sets]
import pkg/[chronos, chronicles, metrics]
import peerinfo, peerstore, stream/connection, muxers/muxer, utils/semaphore, errors
import peerinfo,
peerstore,
stream/connection,
muxers/muxer,
utils/semaphore,
errors
logScope:
topics = "libp2p connmanager"
@@ -24,16 +32,14 @@ const
type
TooManyConnectionsError* = object of LPError
AlreadyExpectingConnectionError* = object of LPError
ConnEventKind* {.pure.} = enum
Connected
# A connection was made and securely upgraded - there may be
# more than one concurrent connection, thus more than one upgrade
# event per peer.
Disconnected
# Peer disconnected - this event is fired once per upgrade
# when the associated connection is terminated.
Connected, # A connection was made and securely upgraded - there may be
# more than one concurrent connection, thus more than one upgrade
# event per peer.
Disconnected # Peer disconnected - this event is fired once per upgrade
# when the associated connection is terminated.
ConnEvent* = object
case kind*: ConnEventKind
@@ -43,31 +49,37 @@ type
discard
ConnEventHandler* =
proc(peerId: PeerId, event: ConnEvent): Future[void] {.gcsafe, raises: [].}
proc(peerId: PeerId, event: ConnEvent): Future[void]
{.gcsafe, raises: [Defect].}
PeerEventKind* {.pure.} = enum
Left
Left,
Identified,
Joined
Identified
PeerEvent* = object
case kind*: PeerEventKind
of PeerEventKind.Joined, PeerEventKind.Identified:
initiator*: bool
else:
discard
of PeerEventKind.Joined:
initiator*: bool
else:
discard
PeerEventHandler* =
proc(peerId: PeerId, event: PeerEvent): Future[void] {.gcsafe, raises: [].}
proc(peerId: PeerId, event: PeerEvent): Future[void] {.gcsafe, raises: [Defect].}
MuxerHolder = object
muxer: Muxer
handle: Future[void]
ConnManager* = ref object of RootObj
maxConnsPerPeer: int
inSema*: AsyncSemaphore
outSema*: AsyncSemaphore
muxed: Table[PeerId, seq[Muxer]]
conns: Table[PeerId, HashSet[Connection]]
muxed: Table[Connection, MuxerHolder]
connEvents: array[ConnEventKind, OrderedSet[ConnEventHandler]]
peerEvents: array[PeerEventKind, OrderedSet[PeerEventHandler]]
expectedConnectionsOverLimit*: Table[(PeerId, Direction), Future[Muxer]]
expectedConnections: Table[PeerId, Future[Connection]]
peerStore*: PeerStore
ConnectionSlot* = object
@@ -77,13 +89,11 @@ type
proc newTooManyConnectionsError(): ref TooManyConnectionsError {.inline.} =
result = newException(TooManyConnectionsError, "Too many connections")
proc new*(
C: type ConnManager,
maxConnsPerPeer = MaxConnectionsPerPeer,
maxConnections = MaxConnections,
maxIn = -1,
maxOut = -1,
): ConnManager =
proc new*(C: type ConnManager,
maxConnsPerPeer = MaxConnectionsPerPeer,
maxConnections = MaxConnections,
maxIn = -1,
maxOut = -1): ConnManager =
var inSema, outSema: AsyncSemaphore
if maxIn > 0 or maxOut > 0:
inSema = newAsyncSemaphore(maxIn)
@@ -94,36 +104,51 @@ proc new*(
else:
raiseAssert "Invalid connection counts!"
C(maxConnsPerPeer: maxConnsPerPeer, inSema: inSema, outSema: outSema)
C(maxConnsPerPeer: maxConnsPerPeer,
inSema: inSema,
outSema: outSema)
proc connCount*(c: ConnManager, peerId: PeerId): int =
c.muxed.getOrDefault(peerId).len
c.conns.getOrDefault(peerId).len
proc connectedPeers*(c: ConnManager, dir: Direction): seq[PeerId] =
var peers = newSeq[PeerId]()
for peerId, mux in c.muxed:
if mux.anyIt(it.connection.dir == dir):
for peerId, conns in c.conns:
if conns.anyIt(it.dir == dir):
peers.add(peerId)
return peers
proc getConnections*(c: ConnManager): Table[PeerId, seq[Muxer]] =
return c.muxed
proc addConnEventHandler*(
c: ConnManager, handler: ConnEventHandler, kind: ConnEventKind
) =
proc addConnEventHandler*(c: ConnManager,
handler: ConnEventHandler,
kind: ConnEventKind) =
## Add connection event handler - handlers must not raise exceptions!
##
if isNil(handler):
return
c.connEvents[kind].incl(handler)
proc removeConnEventHandler*(
c: ConnManager, handler: ConnEventHandler, kind: ConnEventKind
) =
c.connEvents[kind].excl(handler)
try:
if isNil(handler): return
c.connEvents[kind].incl(handler)
except Exception as exc:
# TODO: there is an Exception being raised
# somewhere in the depths of the std.
# Might be related to https://github.com/nim-lang/Nim/issues/17382
proc triggerConnEvent*(c: ConnManager, peerId: PeerId, event: ConnEvent) {.async.} =
raiseAssert exc.msg
proc removeConnEventHandler*(c: ConnManager,
handler: ConnEventHandler,
kind: ConnEventKind) =
try:
c.connEvents[kind].excl(handler)
except Exception as exc:
# TODO: there is an Exception being raised
# somewhere in the depths of the std.
# Might be related to https://github.com/nim-lang/Nim/issues/17382
raiseAssert exc.msg
proc triggerConnEvent*(c: ConnManager,
peerId: PeerId,
event: ConnEvent) {.async, gcsafe.} =
try:
trace "About to trigger connection events", peer = peerId
if c.connEvents[event.kind].len() > 0:
@@ -136,29 +161,54 @@ proc triggerConnEvent*(c: ConnManager, peerId: PeerId, event: ConnEvent) {.async
except CancelledError as exc:
raise exc
except CatchableError as exc:
warn "Exception in triggerConnEvents", msg = exc.msg, peer = peerId, event = $event
warn "Exception in triggerConnEvents",
msg = exc.msg, peer = peerId, event = $event
proc addPeerEventHandler*(
c: ConnManager, handler: PeerEventHandler, kind: PeerEventKind
) =
proc addPeerEventHandler*(c: ConnManager,
handler: PeerEventHandler,
kind: PeerEventKind) =
## Add peer event handler - handlers must not raise exceptions!
##
if isNil(handler):
return
c.peerEvents[kind].incl(handler)
if isNil(handler): return
try:
c.peerEvents[kind].incl(handler)
except Exception as exc:
# TODO: there is an Exception being raised
# somewhere in the depths of the std.
# Might be related to https://github.com/nim-lang/Nim/issues/17382
proc removePeerEventHandler*(
c: ConnManager, handler: PeerEventHandler, kind: PeerEventKind
) =
c.peerEvents[kind].excl(handler)
raiseAssert exc.msg
proc removePeerEventHandler*(c: ConnManager,
handler: PeerEventHandler,
kind: PeerEventKind) =
try:
c.peerEvents[kind].excl(handler)
except Exception as exc:
# TODO: there is an Exception being raised
# somewhere in the depths of the std.
# Might be related to https://github.com/nim-lang/Nim/issues/17382
raiseAssert exc.msg
proc triggerPeerEvents*(c: ConnManager,
peerId: PeerId,
event: PeerEvent) {.async, gcsafe.} =
proc triggerPeerEvents*(c: ConnManager, peerId: PeerId, event: PeerEvent) {.async.} =
trace "About to trigger peer events", peer = peerId
if c.peerEvents[event.kind].len == 0:
return
try:
let count = c.connCount(peerId)
if event.kind == PeerEventKind.Joined and count != 1:
trace "peer already joined", peer = peerId, event = $event
return
elif event.kind == PeerEventKind.Left and count != 0:
trace "peer still connected or already left", peer = peerId, event = $event
return
trace "triggering peer events", peer = peerId, event = $event
var peerEvents: seq[Future[void]]
@@ -171,27 +221,31 @@ proc triggerPeerEvents*(c: ConnManager, peerId: PeerId, event: PeerEvent) {.asyn
except CatchableError as exc: # handlers should not raise!
warn "Exception in triggerPeerEvents", exc = exc.msg, peer = peerId
proc expectConnection*(
c: ConnManager, p: PeerId, dir: Direction
): Future[Muxer] {.async.} =
proc expectConnection*(c: ConnManager, p: PeerId): Future[Connection] {.async.} =
## Wait for a peer to connect to us. This will bypass the `MaxConnectionsPerPeer` limit
let key = (p, dir)
if key in c.expectedConnectionsOverLimit:
raise newException(
AlreadyExpectingConnectionError,
"Already expecting an incoming connection from that peer",
)
if p in c.expectedConnections:
raise LPError.newException("Already expecting a connection from that peer")
let future = newFuture[Muxer]()
c.expectedConnectionsOverLimit[key] = future
let future = newFuture[Connection]()
c.expectedConnections[p] = future
try:
return await future
finally:
c.expectedConnectionsOverLimit.del(key)
c.expectedConnections.del(p)
proc contains*(c: ConnManager, conn: Connection): bool =
## checks if a connection is being tracked by the
## connection manager
##
if isNil(conn):
return
return conn in c.conns.getOrDefault(conn.peerId)
proc contains*(c: ConnManager, peerId: PeerId): bool =
peerId in c.muxed
peerId in c.conns
proc contains*(c: ConnManager, muxer: Muxer): bool =
## checks if a muxer is being tracked by the connection
@@ -199,149 +253,203 @@ proc contains*(c: ConnManager, muxer: Muxer): bool =
##
if isNil(muxer):
return false
return
let conn = muxer.connection
return muxer in c.muxed.getOrDefault(conn.peerId)
if conn notin c:
return
proc closeMuxer(muxer: Muxer) {.async.} =
trace "Cleaning up muxer", m = muxer
if conn notin c.muxed:
return
await muxer.close()
if not (isNil(muxer.handler)):
return muxer == c.muxed.getOrDefault(conn).muxer
proc closeMuxerHolder(muxerHolder: MuxerHolder) {.async.} =
trace "Cleaning up muxer", m = muxerHolder.muxer
await muxerHolder.muxer.close()
if not(isNil(muxerHolder.handle)):
try:
await muxer.handler # TODO noraises?
await muxerHolder.handle # TODO noraises?
except CatchableError as exc:
trace "Exception in close muxer handler", exc = exc.msg
trace "Cleaned up muxer", m = muxer
trace "Cleaned up muxer", m = muxerHolder.muxer
proc delConn(c: ConnManager, conn: Connection) =
let peerId = conn.peerId
c.conns.withValue(peerId, peerConns):
peerConns[].excl(conn)
if peerConns[].len == 0:
c.conns.del(peerId) # invalidates `peerConns`
libp2p_peers.set(c.conns.len.int64)
trace "Removed connection", conn
proc cleanupConn(c: ConnManager, conn: Connection) {.async.} =
## clean up the connection's resources such as muxers and streams
if isNil(conn):
trace "Wont cleanup a nil connection"
return
# Remove connection from all tables without async breaks
var muxer = some(MuxerHolder())
if not c.muxed.pop(conn, muxer.get()):
muxer = none(MuxerHolder)
delConn(c, conn)
proc muxCleanup(c: ConnManager, mux: Muxer) {.async.} =
try:
trace "Triggering disconnect events", mux
let peerId = mux.connection.peerId
if muxer.isSome:
await closeMuxerHolder(muxer.get())
finally:
await conn.close()
let muxers = c.muxed.getOrDefault(peerId).filterIt(it != mux)
if muxers.len > 0:
c.muxed[peerId] = muxers
else:
c.muxed.del(peerId)
libp2p_peers.set(c.muxed.len.int64)
await c.triggerPeerEvents(peerId, PeerEvent(kind: PeerEventKind.Left))
trace "Connection cleaned up", conn
if not (c.peerStore.isNil):
c.peerStore.cleanup(peerId)
proc onConnUpgraded(c: ConnManager, conn: Connection) {.async.} =
try:
trace "Triggering connect events", conn
conn.upgrade()
await c.triggerConnEvent(peerId, ConnEvent(kind: ConnEventKind.Disconnected))
let peerId = conn.peerId
await c.triggerPeerEvents(
peerId, PeerEvent(kind: PeerEventKind.Joined, initiator: conn.dir == Direction.Out))
await c.triggerConnEvent(
peerId, ConnEvent(kind: ConnEventKind.Connected, incoming: conn.dir == Direction.In))
except CatchableError as exc:
# This is a top-level procedure which will run as a separate task, so it
# does not need to propagate CancelledError and should handle other errors
warn "Unexpected exception peer cleanup handler", mux, msg = exc.msg
warn "Unexpected exception in switch peer connection cleanup",
conn, msg = exc.msg
proc onClose(c: ConnManager, mux: Muxer) {.async.} =
proc peerCleanup(c: ConnManager, conn: Connection) {.async.} =
try:
trace "Triggering disconnect events", conn
let peerId = conn.peerId
await c.triggerConnEvent(
peerId, ConnEvent(kind: ConnEventKind.Disconnected))
await c.triggerPeerEvents(peerId, PeerEvent(kind: PeerEventKind.Left))
if not(c.peerStore.isNil):
c.peerStore.cleanup(peerId)
except CatchableError as exc:
# This is a top-level procedure which will run as a separate task, so it
# does not need to propagate CancelledError and should handle other errors
warn "Unexpected exception peer cleanup handler",
conn, msg = exc.msg
proc onClose(c: ConnManager, conn: Connection) {.async.} =
## connection close event handler
##
## triggers the connections resource cleanup
##
try:
await mux.connection.join()
trace "Connection closed, cleaning up", mux
await conn.join()
trace "Connection closed, cleaning up", conn
await c.cleanupConn(conn)
except CancelledError:
# This is a top-level procedure which will run as a separate task, so it
# does not need to propagate CancelledError.
debug "Unexpected cancellation in connection manager's cleanup", conn
except CatchableError as exc:
debug "Unexpected exception in connection manager's cleanup", errMsg = exc.msg, mux
debug "Unexpected exception in connection manager's cleanup",
errMsg = exc.msg, conn
finally:
await c.muxCleanup(mux)
trace "Triggering peerCleanup", conn
asyncSpawn c.peerCleanup(conn)
proc selectMuxer*(c: ConnManager, peerId: PeerId, dir: Direction): Muxer =
proc selectConn*(c: ConnManager,
peerId: PeerId,
dir: Direction): Connection =
## Select a connection for the provided peer and direction
##
let conns = toSeq(c.muxed.getOrDefault(peerId)).filterIt(it.connection.dir == dir)
let conns = toSeq(
c.conns.getOrDefault(peerId))
.filterIt( it.dir == dir )
if conns.len > 0:
return conns[0]
proc selectMuxer*(c: ConnManager, peerId: PeerId): Muxer =
proc selectConn*(c: ConnManager, peerId: PeerId): Connection =
## Select a connection for the provided peer, giving priority
## to outgoing connections
##
var mux = c.selectMuxer(peerId, Direction.Out)
if isNil(mux):
mux = c.selectMuxer(peerId, Direction.In)
if isNil(mux):
var conn = c.selectConn(peerId, Direction.Out)
if isNil(conn):
conn = c.selectConn(peerId, Direction.In)
if isNil(conn):
trace "connection not found", peerId
return mux
proc storeMuxer*(c: ConnManager, muxer: Muxer) {.raises: [CatchableError].} =
## store the connection and muxer
return conn
proc selectMuxer*(c: ConnManager, conn: Connection): Muxer =
## select the muxer for the provided connection
##
if isNil(muxer):
raise newException(LPError, "muxer cannot be nil")
if isNil(conn):
return
if isNil(muxer.connection):
raise newException(LPError, "muxer's connection cannot be nil")
if conn in c.muxed:
return c.muxed.getOrDefault(conn).muxer
else:
debug "no muxer for connection", conn
if muxer.connection.closed or muxer.connection.atEof:
proc storeConn*(c: ConnManager, conn: Connection)
{.raises: [Defect, LPError].} =
## store a connection
##
if isNil(conn):
raise newException(LPError, "Connection cannot be nil")
if conn.closed or conn.atEof:
raise newException(LPError, "Connection closed or EOF")
let
peerId = muxer.connection.peerId
dir = muxer.connection.dir
let peerId = conn.peerId
# we use getOrDefault in the if below instead of [] to avoid the KeyError
if c.muxed.getOrDefault(peerId).len > c.maxConnsPerPeer:
let key = (peerId, dir)
let expectedConn = c.expectedConnectionsOverLimit.getOrDefault(key)
if expectedConn != nil and not expectedConn.finished:
expectedConn.complete(muxer)
else:
debug "Too many connections for peer", conns = c.muxed.getOrDefault(peerId).len
if peerId in c.expectedConnections and
not(c.expectedConnections.getOrDefault(peerId).finished):
c.expectedConnections.getOrDefault(peerId).complete(conn)
elif c.conns.getOrDefault(peerId).len > c.maxConnsPerPeer:
debug "Too many connections for peer",
conn, conns = c.conns.getOrDefault(peerId).len
raise newTooManyConnectionsError()
raise newTooManyConnectionsError()
var newPeer = false
c.muxed.withValue(peerId, muxers):
doAssert muxers[].len > 0
doAssert muxer notin muxers[]
muxers[].add(muxer)
do:
c.muxed[peerId] = @[muxer]
newPeer = true
libp2p_peers.set(c.muxed.len.int64)
c.conns.mgetOrPut(peerId, HashSet[Connection]()).incl(conn)
libp2p_peers.set(c.conns.len.int64)
asyncSpawn c.triggerConnEvent(
peerId, ConnEvent(kind: ConnEventKind.Connected, incoming: dir == Direction.In)
)
# Launch on close listener
# All the errors are handled inside `onClose()` procedure.
asyncSpawn c.onClose(conn)
if newPeer:
asyncSpawn c.triggerPeerEvents(
peerId, PeerEvent(kind: PeerEventKind.Joined, initiator: dir == Direction.Out)
)
asyncSpawn c.onClose(muxer)
trace "Stored muxer", muxer, direction = $muxer.connection.dir, peers = c.muxed.len
trace "Stored connection",
conn, direction = $conn.dir, connections = c.conns.len
proc getIncomingSlot*(c: ConnManager): Future[ConnectionSlot] {.async.} =
await c.inSema.acquire()
return ConnectionSlot(connManager: c, direction: In)
proc getOutgoingSlot*(
c: ConnManager, forceDial = false
): ConnectionSlot {.raises: [TooManyConnectionsError].} =
proc getOutgoingSlot*(c: ConnManager, forceDial = false): ConnectionSlot {.raises: [Defect, TooManyConnectionsError].} =
if forceDial:
c.outSema.forceAcquire()
elif not c.outSema.tryAcquire():
trace "Too many outgoing connections!",
available = c.outSema.count, max = c.outSema.size
trace "Too many outgoing connections!", count = c.outSema.count,
max = c.outSema.size
raise newTooManyConnectionsError()
return ConnectionSlot(connManager: c, direction: Out)
proc slotsAvailable*(c: ConnManager, dir: Direction): int =
case dir
of Direction.In:
return c.inSema.count
of Direction.Out:
return c.outSema.count
case dir:
of Direction.In:
return c.inSema.count
of Direction.Out:
return c.outSema.count
proc release*(cs: ConnectionSlot) =
if cs.direction == In:
@@ -364,41 +472,81 @@ proc trackConnection*(cs: ConnectionSlot, conn: Connection) =
asyncSpawn semaphoreMonitor()
proc trackMuxer*(cs: ConnectionSlot, mux: Muxer) =
if isNil(mux):
cs.release()
return
cs.trackConnection(mux.connection)
proc getStream*(c: ConnManager, muxer: Muxer): Future[Connection] {.async.} =
## get a muxed stream for the passed muxer
proc storeMuxer*(c: ConnManager,
muxer: Muxer,
handle: Future[void] = nil)
{.raises: [Defect, CatchableError].} =
## store the connection and muxer
##
if not (isNil(muxer)):
if isNil(muxer):
raise newException(CatchableError, "muxer cannot be nil")
if isNil(muxer.connection):
raise newException(CatchableError, "muxer's connection cannot be nil")
if muxer.connection notin c:
raise newException(CatchableError, "cant add muxer for untracked connection")
c.muxed[muxer.connection] = MuxerHolder(
muxer: muxer,
handle: handle)
trace "Stored muxer",
muxer, handle = not handle.isNil, connections = c.conns.len
asyncSpawn c.onConnUpgraded(muxer.connection)
proc getStream*(c: ConnManager,
peerId: PeerId,
dir: Direction): Future[Connection] {.async, gcsafe.} =
## get a muxed stream for the provided peer
## with the given direction
##
let muxer = c.selectMuxer(c.selectConn(peerId, dir))
if not(isNil(muxer)):
return await muxer.newStream()
proc getStream*(c: ConnManager, peerId: PeerId): Future[Connection] {.async.} =
proc getStream*(c: ConnManager,
peerId: PeerId): Future[Connection] {.async, gcsafe.} =
## get a muxed stream for the passed peer from any connection
##
return await c.getStream(c.selectMuxer(peerId))
let muxer = c.selectMuxer(c.selectConn(peerId))
if not(isNil(muxer)):
return await muxer.newStream()
proc getStream*(
c: ConnManager, peerId: PeerId, dir: Direction
): Future[Connection] {.async.} =
## get a muxed stream for the passed peer from a connection with `dir`
proc getStream*(c: ConnManager,
conn: Connection): Future[Connection] {.async, gcsafe.} =
## get a muxed stream for the passed connection
##
return await c.getStream(c.selectMuxer(peerId, dir))
let muxer = c.selectMuxer(conn)
if not(isNil(muxer)):
return await muxer.newStream()
proc dropPeer*(c: ConnManager, peerId: PeerId) {.async.} =
## drop connections and clean up resources for peer
##
trace "Dropping peer", peerId
let muxers = c.muxed.getOrDefault(peerId)
let conns = c.conns.getOrDefault(peerId)
for conn in conns:
trace "Removing connection", conn
delConn(c, conn)
var muxers: seq[MuxerHolder]
for conn in conns:
if conn in c.muxed:
muxers.add c.muxed[conn]
c.muxed.del(conn)
for muxer in muxers:
await closeMuxer(muxer)
await closeMuxerHolder(muxer)
for conn in conns:
await conn.close()
trace "Dropped peer", peerId
trace "Peer dropped", peerId
@@ -408,17 +556,24 @@ proc close*(c: ConnManager) {.async.} =
##
trace "Closing ConnManager"
let conns = c.conns
c.conns.clear()
let muxed = c.muxed
c.muxed.clear()
let expected = c.expectedConnectionsOverLimit
c.expectedConnectionsOverLimit.clear()
let expected = c.expectedConnections
c.expectedConnections.clear()
for _, fut in expected:
await fut.cancelAndWait()
for _, muxers in muxed:
for mux in muxers:
await closeMuxer(mux)
for _, muxer in muxed:
await closeMuxerHolder(muxer)
for _, conns2 in conns:
for conn in conns2:
await conn.close()
trace "Closed ConnManager"

View File

@@ -15,11 +15,14 @@
# RFC @ https://tools.ietf.org/html/rfc7539
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import bearssl/blockx
from stew/assign2 import assign
from stew/ptrops import baseAddr
from stew/ranges/ptr_arith import baseAddr
const
ChaChaPolyKeySize = 32
@@ -48,19 +51,17 @@ proc intoChaChaPolyTag*(s: openArray[byte]): ChaChaPolyTag =
# this is reconciled at runtime
# we do this in the global scope / module init
proc encrypt*(
_: type[ChaChaPoly],
key: ChaChaPolyKey,
nonce: ChaChaPolyNonce,
tag: var ChaChaPolyTag,
data: var openArray[byte],
aad: openArray[byte],
) =
let ad =
if aad.len > 0:
unsafeAddr aad[0]
else:
nil
proc encrypt*(_: type[ChaChaPoly],
key: ChaChaPolyKey,
nonce: ChaChaPolyNonce,
tag: var ChaChaPolyTag,
data: var openArray[byte],
aad: openArray[byte]) =
let
ad = if aad.len > 0:
unsafeAddr aad[0]
else:
nil
poly1305CtmulRun(
unsafeAddr key[0],
@@ -71,23 +72,20 @@ proc encrypt*(
uint(aad.len),
baseAddr(tag),
# cast is required to workaround https://github.com/nim-lang/Nim/issues/13905
cast[Chacha20Run](chacha20CtRun), #[encrypt]#
1.cint,
)
cast[Chacha20Run](chacha20CtRun),
#[encrypt]# 1.cint)
proc decrypt*(
_: type[ChaChaPoly],
key: ChaChaPolyKey,
nonce: ChaChaPolyNonce,
tag: var ChaChaPolyTag,
data: var openArray[byte],
aad: openArray[byte],
) =
let ad =
if aad.len > 0:
unsafeAddr aad[0]
else:
nil
proc decrypt*(_: type[ChaChaPoly],
key: ChaChaPolyKey,
nonce: ChaChaPolyNonce,
tag: var ChaChaPolyTag,
data: var openArray[byte],
aad: openArray[byte]) =
let
ad = if aad.len > 0:
unsafeAddr aad[0]
else:
nil
poly1305CtmulRun(
unsafeAddr key[0],
@@ -98,6 +96,5 @@ proc decrypt*(
uint(aad.len),
baseAddr(tag),
# cast is required to workaround https://github.com/nim-lang/Nim/issues/13905
cast[Chacha20Run](chacha20CtRun), #[decrypt]#
0.cint,
)
cast[Chacha20Run](chacha20CtRun),
#[decrypt]# 0.cint)
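(An illustrative in-place roundtrip of the two procs above; the all-zero key and nonce are placeholders for demonstration only — never reuse a nonce with a real key:

var
  key: ChaChaPolyKey # zeroed 32-byte key
  nonce: ChaChaPolyNonce # zeroed nonce
  tag: ChaChaPolyTag
  data = @[byte 1, 2, 3]
let aad: seq[byte] = @[]
ChaChaPoly.encrypt(key, nonce, tag, data, aad) # data becomes ciphertext, tag is filled
ChaChaPoly.decrypt(key, nonce, tag, data, aad) # data is plaintext again
doAssert data == @[byte 1, 2, 3]
)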

View File

@@ -8,17 +8,21 @@
# those terms.
## This module implements the Public Key and Private Key interface for libp2p.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
from strutils import split, strip, cmpIgnoreCase
const libp2p_pki_schemes* {.strdefine.} = "rsa,ed25519,secp256k1,ecnist"
type PKScheme* = enum
RSA = 0
Ed25519
Secp256k1
ECDSA
type
PKScheme* = enum
RSA = 0,
Ed25519,
Secp256k1,
ECDSA
proc initSupportedSchemes(list: static string): set[PKScheme] =
var res: set[PKScheme]
@@ -64,13 +68,11 @@ when supported(PKScheme.Ed25519):
import ed25519/ed25519
when supported(PKScheme.Secp256k1):
import secp
when supported(PKScheme.ECDSA):
import ecnist
# These used to be declared in `crypto` itself
export ecnist.ephemeral, ecnist.ECDHEScheme
# We are still importing `ecnist` because it is used for the SECIO handshake,
# but it will be impossible to create ECNIST keys or import ECNIST keys.
import bearssl/rand, bearssl/hash as bhash
import ecnist, bearssl/rand, bearssl/hash as bhash
import ../protobuf/minprotobuf, ../vbuffer, ../multihash, ../multicodec
import nimcrypto/[rijndael, twofish, sha2, hash, hmac]
# We use `ncrutils` for constant-time hexadecimal encoding/decoding procedures.
@@ -84,9 +86,11 @@ export rijndael, twofish, sha2, hash, hmac, ncrutils, rand
type
DigestSheme* = enum
Sha256
Sha256,
Sha512
ECDHEScheme* = EcCurveKind
PublicKey* = object
case scheme*: PKScheme
of PKScheme.RSA:
@@ -147,16 +151,15 @@ type
data*: seq[byte]
CryptoError* = enum
KeyError
SigError
HashError
KeyError,
SigError,
HashError,
SchemeError
CryptoResult*[T] = Result[T, CryptoError]
template orError*(exp: untyped, err: untyped): untyped =
exp.mapErr do(_: auto) -> auto:
err
(exp.mapErr do (_: auto) -> auto: err)
proc newRng*(): ref HmacDrbgContext =
# You should only create one instance of the RNG per application / library
@@ -172,9 +175,11 @@ proc newRng*(): ref HmacDrbgContext =
return nil
rng
proc shuffle*[T](rng: ref HmacDrbgContext, x: var openArray[T]) =
if x.len == 0:
return
proc shuffle*[T](
rng: ref HmacDrbgContext,
x: var openArray[T]) =
if x.len == 0: return
var randValues = newSeqUninitialized[byte](len(x) * 2)
hmacDrbgGenerate(rng[], randValues)
@@ -185,12 +190,9 @@ proc shuffle*[T](rng: ref HmacDrbgContext, x: var openArray[T]) =
y = rand mod i
swap(x[i], x[y])
proc random*(
T: typedesc[PrivateKey],
scheme: PKScheme,
rng: var HmacDrbgContext,
bits = RsaDefaultKeySize,
): CryptoResult[PrivateKey] =
proc random*(T: typedesc[PrivateKey], scheme: PKScheme,
rng: var HmacDrbgContext,
bits = RsaDefaultKeySize): CryptoResult[PrivateKey] =
## Generate random private key for scheme ``scheme``.
##
## ``bits`` is number of bits for RSA key, ``bits`` value must be in
@@ -198,7 +200,7 @@ proc random*(
case scheme
of PKScheme.RSA:
when supported(PKScheme.RSA):
let rsakey = ?RsaPrivateKey.random(rng, bits).orError(KeyError)
let rsakey = ? RsaPrivateKey.random(rng, bits).orError(KeyError)
ok(PrivateKey(scheme: scheme, rsakey: rsakey))
else:
err(SchemeError)
@@ -210,7 +212,7 @@ proc random*(
err(SchemeError)
of PKScheme.ECDSA:
when supported(PKScheme.ECDSA):
let eckey = ?ecnist.EcPrivateKey.random(Secp256r1, rng).orError(KeyError)
let eckey = ? ecnist.EcPrivateKey.random(Secp256r1, rng).orError(KeyError)
ok(PrivateKey(scheme: scheme, eckey: eckey))
else:
err(SchemeError)
@@ -221,9 +223,8 @@ proc random*(
else:
err(SchemeError)
proc random*(
T: typedesc[PrivateKey], rng: var HmacDrbgContext, bits = RsaDefaultKeySize
): CryptoResult[PrivateKey] =
proc random*(T: typedesc[PrivateKey], rng: var HmacDrbgContext,
bits = RsaDefaultKeySize): CryptoResult[PrivateKey] =
## Generate random private key using default public-key cryptography scheme.
##
## Default public-key cryptography schemes are following order:
@@ -237,20 +238,17 @@ proc random*(
let skkey = SkPrivateKey.random(rng)
ok(PrivateKey(scheme: PKScheme.Secp256k1, skkey: skkey))
elif supported(PKScheme.RSA):
let rsakey = ?RsaPrivateKey.random(rng, bits).orError(KeyError)
let rsakey = ? RsaPrivateKey.random(rng, bits).orError(KeyError)
ok(PrivateKey(scheme: PKScheme.RSA, rsakey: rsakey))
elif supported(PKScheme.ECDSA):
let eckey = ?ecnist.EcPrivateKey.random(Secp256r1, rng).orError(KeyError)
let eckey = ? ecnist.EcPrivateKey.random(Secp256r1, rng).orError(KeyError)
ok(PrivateKey(scheme: PKScheme.ECDSA, eckey: eckey))
else:
err(SchemeError)
proc random*(
T: typedesc[KeyPair],
scheme: PKScheme,
rng: var HmacDrbgContext,
bits = RsaDefaultKeySize,
): CryptoResult[KeyPair] =
proc random*(T: typedesc[KeyPair], scheme: PKScheme,
rng: var HmacDrbgContext,
bits = RsaDefaultKeySize): CryptoResult[KeyPair] =
## Generate random key pair for scheme ``scheme``.
##
## ``bits`` is number of bits for RSA key, ``bits`` value must be in
@@ -258,52 +256,39 @@ proc random*(
case scheme
of PKScheme.RSA:
when supported(PKScheme.RSA):
let pair = ?RsaKeyPair.random(rng, bits).orError(KeyError)
ok(
KeyPair(
seckey: PrivateKey(scheme: scheme, rsakey: pair.seckey),
pubkey: PublicKey(scheme: scheme, rsakey: pair.pubkey),
)
)
let pair = ? RsaKeyPair.random(rng, bits).orError(KeyError)
ok(KeyPair(
seckey: PrivateKey(scheme: scheme, rsakey: pair.seckey),
pubkey: PublicKey(scheme: scheme, rsakey: pair.pubkey)))
else:
err(SchemeError)
of PKScheme.Ed25519:
when supported(PKScheme.Ed25519):
let pair = EdKeyPair.random(rng)
ok(
KeyPair(
seckey: PrivateKey(scheme: scheme, edkey: pair.seckey),
pubkey: PublicKey(scheme: scheme, edkey: pair.pubkey),
)
)
ok(KeyPair(
seckey: PrivateKey(scheme: scheme, edkey: pair.seckey),
pubkey: PublicKey(scheme: scheme, edkey: pair.pubkey)))
else:
err(SchemeError)
of PKScheme.ECDSA:
when supported(PKScheme.ECDSA):
let pair = ?EcKeyPair.random(Secp256r1, rng).orError(KeyError)
ok(
KeyPair(
seckey: PrivateKey(scheme: scheme, eckey: pair.seckey),
pubkey: PublicKey(scheme: scheme, eckey: pair.pubkey),
)
)
let pair = ? EcKeyPair.random(Secp256r1, rng).orError(KeyError)
ok(KeyPair(
seckey: PrivateKey(scheme: scheme, eckey: pair.seckey),
pubkey: PublicKey(scheme: scheme, eckey: pair.pubkey)))
else:
err(SchemeError)
of PKScheme.Secp256k1:
when supported(PKScheme.Secp256k1):
let pair = SkKeyPair.random(rng)
ok(
KeyPair(
seckey: PrivateKey(scheme: scheme, skkey: pair.seckey),
pubkey: PublicKey(scheme: scheme, skkey: pair.pubkey),
)
)
ok(KeyPair(
seckey: PrivateKey(scheme: scheme, skkey: pair.seckey),
pubkey: PublicKey(scheme: scheme, skkey: pair.pubkey)))
else:
err(SchemeError)
proc random*(
T: typedesc[KeyPair], rng: var HmacDrbgContext, bits = RsaDefaultKeySize
): CryptoResult[KeyPair] =
proc random*(T: typedesc[KeyPair], rng: var HmacDrbgContext,
bits = RsaDefaultKeySize): CryptoResult[KeyPair] =
## Generate random private pair of keys using default public-key cryptography
## scheme.
##
@@ -313,36 +298,24 @@ proc random*(
## So will be used first available (supported) method.
when supported(PKScheme.Ed25519):
let pair = EdKeyPair.random(rng)
ok(
KeyPair(
seckey: PrivateKey(scheme: PKScheme.Ed25519, edkey: pair.seckey),
pubkey: PublicKey(scheme: PKScheme.Ed25519, edkey: pair.pubkey),
)
)
ok(KeyPair(
seckey: PrivateKey(scheme: PKScheme.Ed25519, edkey: pair.seckey),
pubkey: PublicKey(scheme: PKScheme.Ed25519, edkey: pair.pubkey)))
elif supported(PKScheme.Secp256k1):
let pair = SkKeyPair.random(rng)
ok(
KeyPair(
seckey: PrivateKey(scheme: PKScheme.Secp256k1, skkey: pair.seckey),
pubkey: PublicKey(scheme: PKScheme.Secp256k1, skkey: pair.pubkey),
)
)
ok(KeyPair(
seckey: PrivateKey(scheme: PKScheme.Secp256k1, skkey: pair.seckey),
pubkey: PublicKey(scheme: PKScheme.Secp256k1, skkey: pair.pubkey)))
elif supported(PKScheme.RSA):
let pair = ?RsaKeyPair.random(rng, bits).orError(KeyError)
ok(
KeyPair(
seckey: PrivateKey(scheme: PKScheme.RSA, rsakey: pair.seckey),
pubkey: PublicKey(scheme: PKScheme.RSA, rsakey: pair.pubkey),
)
)
let pair = ? RsaKeyPair.random(rng, bits).orError(KeyError)
ok(KeyPair(
seckey: PrivateKey(scheme: PKScheme.RSA, rsakey: pair.seckey),
pubkey: PublicKey(scheme: PKScheme.RSA, rsakey: pair.pubkey)))
elif supported(PKScheme.ECDSA):
let pair = ?EcKeyPair.random(Secp256r1, rng).orError(KeyError)
ok(
KeyPair(
seckey: PrivateKey(scheme: PKScheme.ECDSA, eckey: pair.seckey),
pubkey: PublicKey(scheme: PKScheme.ECDSA, eckey: pair.pubkey),
)
)
let pair = ? EcKeyPair.random(Secp256r1, rng).orError(KeyError)
ok(KeyPair(
seckey: PrivateKey(scheme: PKScheme.ECDSA, eckey: pair.seckey),
pubkey: PublicKey(scheme: PKScheme.ECDSA, eckey: pair.pubkey)))
else:
err(SchemeError)
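(A short sketch of the default generation path above — Ed25519 comes first in the fallback order — assuming the default libp2p_pki_schemes:

let rng = newRng() # one RNG instance per application, as noted earlier
doAssert not rng.isNil
let pair = KeyPair.random(rng[]).expect("keypair")
doAssert pair.seckey.scheme == PKScheme.Ed25519
)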
@@ -363,7 +336,7 @@ proc getPublicKey*(key: PrivateKey): CryptoResult[PublicKey] =
err(SchemeError)
of PKScheme.ECDSA:
when supported(PKScheme.ECDSA):
let eckey = ?key.eckey.getPublicKey().orError(KeyError)
let eckey = ? key.eckey.getPublicKey().orError(KeyError)
ok(PublicKey(scheme: ECDSA, eckey: eckey))
else:
err(SchemeError)
@@ -374,9 +347,8 @@ proc getPublicKey*(key: PrivateKey): CryptoResult[PublicKey] =
else:
err(SchemeError)
proc toRawBytes*(
key: PrivateKey | PublicKey, data: var openArray[byte]
): CryptoResult[int] =
proc toRawBytes*(key: PrivateKey | PublicKey,
data: var openArray[byte]): CryptoResult[int] =
## Serialize private key ``key`` (using scheme's own serialization) and store
## it to ``data``.
##
@@ -435,7 +407,7 @@ proc toBytes*(key: PrivateKey, data: var openArray[byte]): CryptoResult[int] =
## Returns number of bytes (octets) needed to store private key ``key``.
var msg = initProtoBuffer()
msg.write(1, uint64(key.scheme))
msg.write(2, ?key.getRawBytes())
msg.write(2, ? key.getRawBytes())
msg.finish()
var blen = len(msg.buffer)
if len(data) >= blen:
@@ -449,7 +421,7 @@ proc toBytes*(key: PublicKey, data: var openArray[byte]): CryptoResult[int] =
## Returns number of bytes (octets) needed to store public key ``key``.
var msg = initProtoBuffer()
msg.write(1, uint64(key.scheme))
msg.write(2, ?key.getRawBytes())
msg.write(2, ? key.getRawBytes())
msg.finish()
var blen = len(msg.buffer)
if len(data) >= blen and blen > 0:
@@ -469,7 +441,7 @@ proc getBytes*(key: PrivateKey): CryptoResult[seq[byte]] =
## serialization).
var msg = initProtoBuffer()
msg.write(1, uint64(key.scheme))
msg.write(2, ?key.getRawBytes())
msg.write(2, ? key.getRawBytes())
msg.finish()
ok(msg.buffer)
@@ -478,7 +450,7 @@ proc getBytes*(key: PublicKey): CryptoResult[seq[byte]] =
## serialization).
var msg = initProtoBuffer()
msg.write(1, uint64(key.scheme))
msg.write(2, ?key.getRawBytes())
msg.write(2, ? key.getRawBytes())
msg.finish()
ok(msg.buffer)
@@ -486,7 +458,7 @@ proc getBytes*(sig: Signature): seq[byte] =
## Return signature ``sig`` in binary form.
result = sig.data
template initImpl[T: PrivateKey | PublicKey](key: var T, data: openArray[byte]): bool =
proc init*[T: PrivateKey|PublicKey](key: var T, data: openArray[byte]): bool =
## Initialize private/public key ``key`` from libp2p's protobuf serialized
## raw binary form.
##
@@ -499,7 +471,7 @@ template initImpl[T: PrivateKey | PublicKey](key: var T, data: openArray[byte]):
var pb = initProtoBuffer(@data)
let r1 = pb.getField(1, id)
let r2 = pb.getField(2, buffer)
if not (r1.get(false) and r2.get(false)):
if not(r1.isOk() and r1.get() and r2.isOk() and r2.get()):
false
else:
if cast[int8](id) notin SupportedSchemesInt or len(buffer) <= 0:
@@ -510,7 +482,7 @@ template initImpl[T: PrivateKey | PublicKey](key: var T, data: openArray[byte]):
var nkey = PrivateKey(scheme: scheme)
else:
var nkey = PublicKey(scheme: scheme)
case scheme
case scheme:
of PKScheme.RSA:
when supported(PKScheme.RSA):
if init(nkey.rsakey, buffer).isOk:
@@ -548,15 +520,6 @@ template initImpl[T: PrivateKey | PublicKey](key: var T, data: openArray[byte]):
else:
false
{.push warning[ProveField]: off.} # https://github.com/nim-lang/Nim/issues/22060
proc init*(key: var PrivateKey, data: openArray[byte]): bool =
initImpl(key, data)
proc init*(key: var PublicKey, data: openArray[byte]): bool =
initImpl(key, data)
{.pop.}
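# Round-trip sketch for the protobuf envelope (field 1 = scheme, field 2 =
# raw key bytes); `newRng()` is an assumed helper.
block:
  let rng = newRng()
  let seckey = PrivateKey.random(rng[]).expect("supported scheme")
  let encoded = seckey.getBytes().expect("serializable key")
  var decoded: PrivateKey
  doAssert decoded.init(encoded) # `false` would mean a malformed blob
  doAssert decoded == seckey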
proc init*(sig: var Signature, data: openArray[byte]): bool =
## Initialize signature ``sig`` from raw binary form.
##
@@ -565,7 +528,7 @@ proc init*(sig: var Signature, data: openArray[byte]): bool =
sig.data = @data
result = true
proc init*[T: PrivateKey | PublicKey](key: var T, data: string): bool =
proc init*[T: PrivateKey|PublicKey](key: var T, data: string): bool =
## Initialize private/public key ``key`` from libp2p's protobuf serialized
## hexadecimal string representation.
##
@@ -579,7 +542,8 @@ proc init*(sig: var Signature, data: string): bool =
## Returns ``true`` on success.
sig.init(ncrutils.fromHex(data))
proc init*(t: typedesc[PrivateKey], data: openArray[byte]): CryptoResult[PrivateKey] =
proc init*(t: typedesc[PrivateKey],
data: openArray[byte]): CryptoResult[PrivateKey] =
## Create new private key from libp2p's protobuf serialized binary form.
var res: t
if not res.init(data):
@@ -587,7 +551,8 @@ proc init*(t: typedesc[PrivateKey], data: openArray[byte]): CryptoResult[Private
else:
ok(res)
proc init*(t: typedesc[PublicKey], data: openArray[byte]): CryptoResult[PublicKey] =
proc init*(t: typedesc[PublicKey],
data: openArray[byte]): CryptoResult[PublicKey] =
## Create new public key from libp2p's protobuf serialized binary form.
var res: t
if not res.init(data):
@@ -595,7 +560,8 @@ proc init*(t: typedesc[PublicKey], data: openArray[byte]): CryptoResult[PublicKe
else:
ok(res)
proc init*(t: typedesc[Signature], data: openArray[byte]): CryptoResult[Signature] =
proc init*(t: typedesc[Signature],
data: openArray[byte]): CryptoResult[Signature] =
## Create new signature from libp2p's protobuf serialized binary form.
var res: t
if not res.init(data):
@@ -611,28 +577,24 @@ proc init*(t: typedesc[PrivateKey], data: string): CryptoResult[PrivateKey] =
when supported(PKScheme.RSA):
proc init*(t: typedesc[PrivateKey], key: rsa.RsaPrivateKey): PrivateKey =
PrivateKey(scheme: RSA, rsakey: key)
proc init*(t: typedesc[PublicKey], key: rsa.RsaPublicKey): PublicKey =
PublicKey(scheme: RSA, rsakey: key)
when supported(PKScheme.Ed25519):
proc init*(t: typedesc[PrivateKey], key: EdPrivateKey): PrivateKey =
PrivateKey(scheme: Ed25519, edkey: key)
proc init*(t: typedesc[PublicKey], key: EdPublicKey): PublicKey =
PublicKey(scheme: Ed25519, edkey: key)
when supported(PKScheme.Secp256k1):
proc init*(t: typedesc[PrivateKey], key: SkPrivateKey): PrivateKey =
PrivateKey(scheme: Secp256k1, skkey: key)
proc init*(t: typedesc[PublicKey], key: SkPublicKey): PublicKey =
PublicKey(scheme: Secp256k1, skkey: key)
when supported(PKScheme.ECDSA):
proc init*(t: typedesc[PrivateKey], key: ecnist.EcPrivateKey): PrivateKey =
PrivateKey(scheme: ECDSA, eckey: key)
proc init*(t: typedesc[PublicKey], key: ecnist.EcPublicKey): PublicKey =
PublicKey(scheme: ECDSA, eckey: key)
@@ -701,9 +663,9 @@ proc `==`*(key1, key2: PrivateKey): bool =
else:
false
proc `$`*(key: PrivateKey | PublicKey): string =
proc `$`*(key: PrivateKey|PublicKey): string =
## Get string representation of private/public key ``key``.
case key.scheme
case key.scheme:
of PKScheme.RSA:
when supported(PKScheme.RSA):
$(key.rsakey)
@@ -725,9 +687,9 @@ proc `$`*(key: PrivateKey | PublicKey): string =
else:
"unsupported secp256k1 key"
func shortLog*(key: PrivateKey | PublicKey): string =
func shortLog*(key: PrivateKey|PublicKey): string =
## Get short string representation of private/public key ``key``.
case key.scheme
case key.scheme:
of PKScheme.RSA:
when supported(PKScheme.RSA):
($key.rsakey).shortLog
@@ -753,15 +715,16 @@ proc `$`*(sig: Signature): string =
## Get string representation of signature ``sig``.
result = ncrutils.toHex(sig.data)
proc sign*(key: PrivateKey, data: openArray[byte]): CryptoResult[Signature] {.gcsafe.} =
proc sign*(key: PrivateKey,
data: openArray[byte]): CryptoResult[Signature] {.gcsafe.} =
## Sign message ``data`` using private key ``key`` and return generated
## signature in raw binary form.
var res: Signature
case key.scheme
case key.scheme:
of PKScheme.RSA:
when supported(PKScheme.RSA):
let sig = ?key.rsakey.sign(data).orError(SigError)
res.data = ?sig.getBytes().orError(SigError)
let sig = ? key.rsakey.sign(data).orError(SigError)
res.data = ? sig.getBytes().orError(SigError)
ok(res)
else:
err(SchemeError)
@@ -774,8 +737,8 @@ proc sign*(key: PrivateKey, data: openArray[byte]): CryptoResult[Signature] {.gc
err(SchemeError)
of PKScheme.ECDSA:
when supported(PKScheme.ECDSA):
let sig = ?key.eckey.sign(data).orError(SigError)
res.data = ?sig.getBytes().orError(SigError)
let sig = ? key.eckey.sign(data).orError(SigError)
res.data = ? sig.getBytes().orError(SigError)
ok(res)
else:
err(SchemeError)
@@ -790,7 +753,7 @@ proc sign*(key: PrivateKey, data: openArray[byte]): CryptoResult[Signature] {.gc
proc verify*(sig: Signature, message: openArray[byte], key: PublicKey): bool =
## Verify signature ``sig`` using message ``message`` and public key ``key``.
## Return ``true`` if message signature is valid.
case key.scheme
case key.scheme:
of PKScheme.RSA:
when supported(PKScheme.RSA):
var signature: RsaSignature
@@ -828,12 +791,12 @@ proc verify*(sig: Signature, message: openArray[byte], key: PublicKey): bool =
else:
false
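# Sign/verify round-trip sketch; `newRng()` and `toBytes` (stew/byteutils)
# are assumed available.
block:
  let rng = newRng()
  let pair = KeyPair.random(rng[]).expect("supported scheme")
  let msg = "hello libp2p".toBytes()
  let sig = pair.seckey.sign(msg).expect("signable key")
  doAssert sig.verify(msg, pair.pubkey)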
template makeSecret(buffer, hmactype, secret, seed: untyped) {.dirty.} =
template makeSecret(buffer, hmactype, secret, seed: untyped) {.dirty.}=
var ctx: hmactype
var j = 0
# We need to strip leading zeros, because Go's bigint serialization does it.
var offset = 0
for i in 0 ..< len(secret):
for i in 0..<len(secret):
if secret[i] != 0x00'u8:
break
inc(offset)
@@ -854,9 +817,8 @@ template makeSecret(buffer, hmactype, secret, seed: untyped) {.dirty.} =
ctx.update(a.data)
a = ctx.finish()
proc stretchKeys*(
cipherType: string, hashType: string, sharedSecret: seq[byte]
): Secret =
proc stretchKeys*(cipherType: string, hashType: string,
sharedSecret: seq[byte]): Secret =
## Expand shared secret to cryptographic keys.
if cipherType == "AES-128":
result.ivsize = aes128.sizeBlock
@@ -882,57 +844,65 @@ template goffset*(secret, id, o: untyped): untyped =
id * (len(secret.data) shr 1) + o
template ivOpenArray*(secret: Secret, id: int): untyped =
toOpenArray(
secret.data, goffset(secret, id, 0), goffset(secret, id, secret.ivsize - 1)
)
toOpenArray(secret.data, goffset(secret, id, 0),
goffset(secret, id, secret.ivsize - 1))
template keyOpenArray*(secret: Secret, id: int): untyped =
toOpenArray(
secret.data,
goffset(secret, id, secret.ivsize),
goffset(secret, id, secret.ivsize + secret.keysize - 1),
)
toOpenArray(secret.data, goffset(secret, id, secret.ivsize),
goffset(secret, id, secret.ivsize + secret.keysize - 1))
template macOpenArray*(secret: Secret, id: int): untyped =
toOpenArray(
secret.data,
goffset(secret, id, secret.ivsize + secret.keysize),
goffset(secret, id, secret.ivsize + secret.keysize + secret.macsize - 1),
)
toOpenArray(secret.data, goffset(secret, id, secret.ivsize + secret.keysize),
goffset(secret, id, secret.ivsize + secret.keysize + secret.macsize - 1))
proc iv*(secret: Secret, id: int): seq[byte] {.inline.} =
## Get array of bytes with the initial vector.
result = newSeq[byte](secret.ivsize)
var offset =
if id == 0:
0
else:
(len(secret.data) div 2)
var offset = if id == 0: 0 else: (len(secret.data) div 2)
copyMem(addr result[0], unsafeAddr secret.data[offset], secret.ivsize)
proc key*(secret: Secret, id: int): seq[byte] {.inline.} =
result = newSeq[byte](secret.keysize)
var offset =
if id == 0:
0
else:
(len(secret.data) div 2)
var offset = if id == 0: 0 else: (len(secret.data) div 2)
offset += secret.ivsize
copyMem(addr result[0], unsafeAddr secret.data[offset], secret.keysize)
proc mac*(secret: Secret, id: int): seq[byte] {.inline.} =
result = newSeq[byte](secret.macsize)
var offset =
if id == 0:
0
else:
(len(secret.data) div 2)
var offset = if id == 0: 0 else: (len(secret.data) div 2)
offset += secret.ivsize + secret.keysize
copyMem(addr result[0], unsafeAddr secret.data[offset], secret.macsize)
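# Layout sketch: the stretched secret holds two symmetric halves (id 0 and
# id 1), each laid out as IV | cipher key | MAC key. `sharedSecret` is a
# hypothetical ECDHE output.
block:
  let secret = stretchKeys("AES-128", "SHA256", sharedSecret)
  doAssert secret.iv(0).len == secret.ivsize
  doAssert secret.key(0).len == secret.keysize
  doAssert secret.mac(0).len == secret.macsize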
proc getOrder*(
remotePubkey, localNonce: openArray[byte], localPubkey, remoteNonce: openArray[byte]
): CryptoResult[int] =
proc ephemeral*(
scheme: ECDHEScheme,
rng: var HmacDrbgContext): CryptoResult[EcKeyPair] =
## Generate ephemeral keys used to perform ECDHE.
var keypair: EcKeyPair
if scheme == Secp256r1:
keypair = ? EcKeyPair.random(Secp256r1, rng).orError(KeyError)
elif scheme == Secp384r1:
keypair = ? EcKeyPair.random(Secp384r1, rng).orError(KeyError)
elif scheme == Secp521r1:
keypair = ? EcKeyPair.random(Secp521r1, rng).orError(KeyError)
ok(keypair)
proc ephemeral*(
scheme: string, rng: var HmacDrbgContext): CryptoResult[EcKeyPair] =
## Generate ephemeral keys used to perform ECDHE using string encoding.
##
## Currently supported encoding strings are P-256, P-384 and P-521; if the
## encoding string is not supported, a P-521 key will be generated.
if scheme == "P-256":
ephemeral(Secp256r1, rng)
elif scheme == "P-384":
ephemeral(Secp384r1, rng)
elif scheme == "P-521":
ephemeral(Secp521r1, rng)
else:
ephemeral(Secp521r1, rng)
proc getOrder*(remotePubkey, localNonce: openArray[byte],
localPubkey, remoteNonce: openArray[byte]): CryptoResult[int] =
## Compare values and calculate `order` parameter.
var ctx: sha256
ctx.init()
@@ -943,9 +913,9 @@ proc getOrder*(
ctx.update(localPubkey)
ctx.update(remoteNonce)
var digest2 = ctx.finish()
var mh1 = ?MultiHash.init(multiCodec("sha2-256"), digest1).orError(HashError)
var mh2 = ?MultiHash.init(multiCodec("sha2-256"), digest2).orError(HashError)
var res = 0
var mh1 = ? MultiHash.init(multiCodec("sha2-256"), digest1).orError(HashError)
var mh2 = ? MultiHash.init(multiCodec("sha2-256"), digest2).orError(HashError)
var res = 0;
for i in 0 ..< len(mh1.data.buffer):
res = int(mh1.data.buffer[i]) - int(mh2.data.buffer[i])
if res != 0:
@@ -976,45 +946,96 @@ proc selectBest*(order: int, p1, p2: string): string =
if felement == selement:
return felement
proc createProposal*(nonce, pubkey: openArray[byte],
exchanges, ciphers, hashes: string): seq[byte] =
## Create SecIO proposal message using random ``nonce``, local public key
## ``pubkey``, comma-delimited list of supported exchange schemes
## ``exchanges``, comma-delimited list of supported ciphers ``ciphers`` and
## comma-delimited list of supported hashes ``hashes``.
var msg = initProtoBuffer({WithUint32BeLength})
msg.write(1, nonce)
msg.write(2, pubkey)
msg.write(3, exchanges)
msg.write(4, ciphers)
msg.write(5, hashes)
msg.finish()
msg.buffer
proc decodeProposal*(message: seq[byte], nonce, pubkey: var seq[byte],
exchanges, ciphers, hashes: var string): bool =
## Parse incoming proposal message and decode remote random nonce ``nonce``,
## remote public key ``pubkey``, comma-delimited list of supported exchange
## schemes ``exchanges``, comma-delimited list of supported ciphers
## ``ciphers`` and comma-delimited list of supported hashes ``hashes``.
##
## Procedure returns ``true`` on success and ``false`` on error.
var pb = initProtoBuffer(message)
let r1 = pb.getField(1, nonce)
let r2 = pb.getField(2, pubkey)
let r3 = pb.getField(3, exchanges)
let r4 = pb.getField(4, ciphers)
let r5 = pb.getField(5, hashes)
r1.isOk() and r1.get() and r2.isOk() and r2.get() and
r3.isOk() and r3.get() and r4.isOk() and r4.get() and
r5.isOk() and r5.get()
proc createExchange*(epubkey, signature: openArray[byte]): seq[byte] =
## Create SecIO exchange message using ephemeral public key ``epubkey`` and
## signature of proposal blocks ``signature``.
var msg = initProtoBuffer({WithUint32BeLength})
msg.write(1, epubkey)
msg.write(2, signature)
msg.finish()
msg.buffer
proc decodeExchange*(message: seq[byte],
pubkey, signature: var seq[byte]): bool =
## Parse incoming exchange message and decode remote ephemeral public key
## ``pubkey`` and signature ``signature``.
##
## Procedure returns ``true`` on success and ``false`` on error.
var pb = initProtoBuffer(message)
let r1 = pb.getField(1, pubkey)
let r2 = pb.getField(2, signature)
r1.isOk() and r1.get() and r2.isOk() and r2.get()
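# Proposal round-trip sketch: `createProposal` frames the message with a
# 4-byte big-endian length prefix (`WithUint32BeLength`), which the wire
# layer strips before `decodeProposal` parses the payload. `localNonce` and
# `localPubkeyBytes` are hypothetical values.
block:
  let framed = createProposal(localNonce, localPubkeyBytes,
                              "P-256,P-384", "AES-128,AES-256", "SHA256,SHA512")
  var nonce, pubkey: seq[byte]
  var exchanges, ciphers, hashes: string
  doAssert decodeProposal(framed[4 .. ^1], nonce, pubkey,
                          exchanges, ciphers, hashes)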
## Serialization/Deserialization helpers
proc write*(
vb: var VBuffer, pubkey: PublicKey
) {.inline, raises: [ResultError[CryptoError]].} =
proc write*(vb: var VBuffer, pubkey: PublicKey) {.
inline, raises: [Defect, ResultError[CryptoError]].} =
## Write PublicKey value ``pubkey`` to buffer ``vb``.
vb.writeSeq(pubkey.getBytes().tryGet())
proc write*(
vb: var VBuffer, seckey: PrivateKey
) {.inline, raises: [ResultError[CryptoError]].} =
proc write*(vb: var VBuffer, seckey: PrivateKey) {.
inline, raises: [Defect, ResultError[CryptoError]].} =
## Write PrivateKey value ``seckey`` to buffer ``vb``.
vb.writeSeq(seckey.getBytes().tryGet())
proc write*(
vb: var VBuffer, sig: PrivateKey
) {.inline, raises: [ResultError[CryptoError]].} =
proc write*(vb: var VBuffer, sig: PrivateKey) {.
inline, raises: [Defect, ResultError[CryptoError]].} =
## Write Signature value ``sig`` to buffer ``vb``.
vb.writeSeq(sig.getBytes().tryGet())
proc write*[T: PublicKey | PrivateKey](
pb: var ProtoBuffer, field: int, key: T
) {.inline, raises: [ResultError[CryptoError]].} =
proc write*[T: PublicKey|PrivateKey](pb: var ProtoBuffer, field: int,
key: T) {.
inline, raises: [Defect, ResultError[CryptoError]].} =
write(pb, field, key.getBytes().tryGet())
proc write*(pb: var ProtoBuffer, field: int, sig: Signature) {.inline, raises: [].} =
proc write*(pb: var ProtoBuffer, field: int, sig: Signature) {.
inline, raises: [Defect].} =
write(pb, field, sig.getBytes())
proc getField*[T: PublicKey | PrivateKey](
pb: ProtoBuffer, field: int, value: var T
): ProtoResult[bool] =
proc getField*[T: PublicKey|PrivateKey](pb: ProtoBuffer, field: int,
value: var T): ProtoResult[bool] =
## Deserialize public/private key from protobuf's message ``pb`` using field
## index ``field``.
##
## On success deserialized key will be stored in ``value``.
var buffer: seq[byte]
var key: T
let res = ?pb.getField(field, buffer)
if not (res):
let res = ? pb.getField(field, buffer)
if not(res):
ok(false)
else:
if key.init(buffer):
@@ -1023,15 +1044,16 @@ proc getField*[T: PublicKey | PrivateKey](
else:
err(ProtoError.IncorrectBlob)
proc getField*(pb: ProtoBuffer, field: int, value: var Signature): ProtoResult[bool] =
proc getField*(pb: ProtoBuffer, field: int,
value: var Signature): ProtoResult[bool] =
## Deserialize signature from protobuf's message ``pb`` using field index
## ``field``.
##
## On success deserialized signature will be stored in ``value``.
var buffer: seq[byte]
var sig: Signature
let res = ?pb.getField(field, buffer)
if not (res):
let res = ? pb.getField(field, buffer)
if not(res):
ok(false)
else:
if sig.init(buffer):

View File

@@ -15,14 +15,18 @@
# RFC @ https://tools.ietf.org/html/rfc7748
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import bearssl/[ec, rand]
import bearssl/[ec, rand, hash]
import stew/results
from stew/assign2 import assign
export results
const Curve25519KeySize* = 32
const
Curve25519KeySize* = 32
type
Curve25519* = object
@@ -34,12 +38,12 @@ proc intoCurve25519Key*(s: openArray[byte]): Curve25519Key =
assert s.len == Curve25519KeySize
assign(result, s)
proc getBytes*(key: Curve25519Key): seq[byte] =
@key
proc getBytes*(key: Curve25519Key): seq[byte] = @key
proc byteswap(buf: var Curve25519Key) {.inline.} =
for i in 0 ..< 16:
let x = buf[i]
for i in 0..<16:
let
x = buf[i]
buf[i] = buf[31 - i]
buf[31 - i] = x
@@ -47,25 +51,31 @@ proc mul*(_: type[Curve25519], point: var Curve25519Key, multiplier: Curve25519K
let defaultBrEc = ecGetDefault()
# multiplier needs to be big-endian
var multiplierBs = multiplier
var
multiplierBs = multiplier
multiplierBs.byteswap()
let res = defaultBrEc.mul(
addr point[0],
Curve25519KeySize,
addr multiplierBs[0],
Curve25519KeySize,
EC_curve25519,
)
let
res = defaultBrEc.mul(
addr point[0],
Curve25519KeySize,
addr multiplierBs[0],
Curve25519KeySize,
EC_curve25519)
assert res == 1
proc mulgen(_: type[Curve25519], dst: var Curve25519Key, point: Curve25519Key) =
let defaultBrEc = ecGetDefault()
var rpoint = point
var
rpoint = point
rpoint.byteswap()
let size =
defaultBrEc.mulgen(addr dst[0], addr rpoint[0], Curve25519KeySize, EC_curve25519)
let
size = defaultBrEc.mulgen(
addr dst[0],
addr rpoint[0],
Curve25519KeySize,
EC_curve25519)
assert size == Curve25519KeySize
@@ -75,7 +85,8 @@ proc public*(private: Curve25519Key): Curve25519Key =
proc random*(_: type[Curve25519Key], rng: var HmacDrbgContext): Curve25519Key =
var res: Curve25519Key
let defaultBrEc = ecGetDefault()
let len = ecKeygen(addr rng.vtable, defaultBrEc, nil, addr res[0], EC_curve25519)
let len = ecKeygen(
addr rng.vtable, defaultBrEc, nil, addr res[0], EC_curve25519)
# Per bearssl documentation, the keygen only fails if the curve is
# unrecognised -
doAssert len == Curve25519KeySize, "Could not generate curve"
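# X25519 agreement sketch: both parties derive the same shared point.
# `newRng()` (from nim-libp2p's crypto module) is an assumed helper.
block:
  let rng = newRng()
  let aliceSec = Curve25519Key.random(rng[])
  let bobSec = Curve25519Key.random(rng[])
  var s1 = public(bobSec)   # Alice scales Bob's public point ...
  var s2 = public(aliceSec) # ... and Bob scales Alice's.
  Curve25519.mul(s1, aliceSec)
  Curve25519.mul(s2, bobSec)
  doAssert s1 == s2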

View File

@@ -14,7 +14,10 @@
## BearSSL library <https://bearssl.org/>
## Copyright(C) 2018 Thomas Pornin <pornin@bolet.org>.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import bearssl/[ec, rand, hash]
# We use `ncrutils` for constant-time hexadecimal encoding/decoding procedures.
@@ -22,9 +25,6 @@ import nimcrypto/utils as ncrutils
import minasn1
export minasn1.Asn1Error
import stew/[results, ctops]
import ../utility
export results
const
@@ -58,22 +58,23 @@ type
buffer*: seq[byte]
EcCurveKind* = enum
Secp256r1 = EC_secp256r1
Secp384r1 = EC_secp384r1
Secp256r1 = EC_secp256r1,
Secp384r1 = EC_secp384r1,
Secp521r1 = EC_secp521r1
EcPKI* = EcPrivateKey | EcPublicKey | EcSignature
EcError* = enum
EcRngError
EcKeyGenError
EcPublicKeyError
EcKeyIncorrectError
EcRngError,
EcKeyGenError,
EcPublicKeyError,
EcKeyIncorrectError,
EcSignatureError
EcResult*[T] = Result[T, EcError]
const EcSupportedCurvesCint* = @[cint(Secp256r1), cint(Secp384r1), cint(Secp521r1)]
const
EcSupportedCurvesCint* = {cint(Secp256r1), cint(Secp384r1), cint(Secp521r1)}
proc `-`(x: uint32): uint32 {.inline.} =
result = (0xFFFF_FFFF'u32 - x) + 1'u32
@@ -87,7 +88,7 @@ proc CMP(x, y: uint32): int32 {.inline.} =
proc EQ0(x: int32): uint32 {.inline.} =
var q = cast[uint32](x)
result = not (q or -q) shr 31
result = not(q or -q) shr 31
proc NEQ(x, y: uint32): uint32 {.inline.} =
var q = cast[uint32](x xor y)
@@ -112,7 +113,7 @@ proc checkScalar(scalar: openArray[byte], curve: cint): uint32 =
for u in scalar:
z = z or u
if len(scalar) == int(orderlen):
for i in 0 ..< len(scalar):
for i in 0..<len(scalar):
c = c or (-(cast[int32](EQ0(c))) and CMP(scalar[i], order[i]))
else:
c = -1
@@ -125,7 +126,8 @@ proc checkPublic(key: openArray[byte], curve: cint): uint32 =
var impl = ecGetDefault()
var orderlen: uint = 0
discard impl.order(curve, orderlen)
result = impl.mul(unsafeAddr ckey[0], uint(len(ckey)), addr x[0], uint(len(x)), curve)
result = impl.mul(unsafeAddr ckey[0], uint(len(ckey)),
addr x[0], uint(len(x)), curve)
proc getOffset(pubkey: EcPublicKey): int {.inline.} =
let o = cast[uint](pubkey.key.q) - cast[uint](unsafeAddr pubkey.buffer[0])
@@ -143,15 +145,21 @@ proc getOffset(seckey: EcPrivateKey): int {.inline.} =
template getPublicKeyLength*(curve: EcCurveKind): int =
case curve
of Secp256r1: PubKey256Length
of Secp384r1: PubKey384Length
of Secp521r1: PubKey521Length
of Secp256r1:
PubKey256Length
of Secp384r1:
PubKey384Length
of Secp521r1:
PubKey521Length
template getPrivateKeyLength*(curve: EcCurveKind): int =
case curve
of Secp256r1: SecKey256Length
of Secp384r1: SecKey384Length
of Secp521r1: SecKey521Length
of Secp256r1:
SecKey256Length
of Secp384r1:
SecKey384Length
of Secp521r1:
SecKey521Length
proc copy*[T: EcPKI](dst: var T, src: T): bool =
## Copy EC `private key`, `public key` or `signature` ``src`` to ``dst``.
@@ -193,7 +201,7 @@ proc copy*[T: EcPKI](src: T): T {.inline.} =
if not copy(result, src):
raise newException(EcKeyIncorrectError, "Incorrect key or signature")
proc clear*[T: EcPKI | EcKeyPair](pki: var T) =
proc clear*[T: EcPKI|EcKeyPair](pki: var T) =
## Wipe and clear EC `private key`, `public key` or `signature` object.
doAssert(not isNil(pki))
when T is EcPrivateKey:
@@ -224,8 +232,8 @@ proc clear*[T: EcPKI | EcKeyPair](pki: var T) =
pki.pubkey.key.curve = 0
proc random*(
T: typedesc[EcPrivateKey], kind: EcCurveKind, rng: var HmacDrbgContext
): EcResult[EcPrivateKey] =
T: typedesc[EcPrivateKey], kind: EcCurveKind,
rng: var HmacDrbgContext): EcResult[EcPrivateKey] =
## Generate new random EC private key using BearSSL's HMAC-SHA256-DRBG
## algorithm.
##
@@ -233,9 +241,9 @@ proc random*(
## secp521r1).
var ecimp = ecGetDefault()
var res = new EcPrivateKey
if ecKeygen(
addr rng.vtable, ecimp, addr res.key, addr res.buffer[0], safeConvert[cint](kind)
) == 0:
if ecKeygen(addr rng.vtable, ecimp,
addr res.key, addr res.buffer[0],
cast[cint](kind)) == 0:
err(EcKeyGenError)
else:
ok(res)
@@ -249,7 +257,8 @@ proc getPublicKey*(seckey: EcPrivateKey): EcResult[EcPublicKey] =
if seckey.key.curve in EcSupportedCurvesCint:
var res = new EcPublicKey
assert res.buffer.len > getPublicKeyLength(cast[EcCurveKind](seckey.key.curve))
if ecComputePub(ecimp, addr res.key, addr res.buffer[0], unsafeAddr seckey.key) == 0:
if ecComputePub(ecimp, addr res.key,
addr res.buffer[0], unsafeAddr seckey.key) == 0:
err(EcKeyIncorrectError)
else:
ok(res)
@@ -257,23 +266,23 @@ proc getPublicKey*(seckey: EcPrivateKey): EcResult[EcPublicKey] =
err(EcKeyIncorrectError)
proc random*(
T: typedesc[EcKeyPair], kind: EcCurveKind, rng: var HmacDrbgContext
): EcResult[T] =
T: typedesc[EcKeyPair], kind: EcCurveKind,
rng: var HmacDrbgContext): EcResult[T] =
## Generate new random EC private and public keypair using BearSSL's
## HMAC-SHA256-DRBG algorithm.
##
## ``kind`` elliptic curve kind of your choice (secp256r1, secp384r1 or
## secp521r1).
let
seckey = ?EcPrivateKey.random(kind, rng)
pubkey = ?seckey.getPublicKey()
seckey = ? EcPrivateKey.random(kind, rng)
pubkey = ? seckey.getPublicKey()
key = EcKeyPair(seckey: seckey, pubkey: pubkey)
ok(key)
proc `$`*(seckey: EcPrivateKey): string =
## Return string representation of EC private key.
if isNil(seckey) or seckey.key.curve == 0 or seckey.key.xlen == 0 or
len(seckey.buffer) == 0:
len(seckey.buffer) == 0:
result = "Empty or uninitialized ECNIST key"
else:
if seckey.key.curve notin EcSupportedCurvesCint:
@@ -289,7 +298,7 @@ proc `$`*(seckey: EcPrivateKey): string =
proc `$`*(pubkey: EcPublicKey): string =
## Return string representation of EC public key.
if isNil(pubkey) or pubkey.key.curve == 0 or pubkey.key.qlen == 0 or
len(pubkey.buffer) == 0:
len(pubkey.buffer) == 0:
result = "Empty or uninitialized ECNIST key"
else:
if pubkey.key.curve notin EcSupportedCurvesCint:
@@ -362,7 +371,7 @@ proc toBytes*(seckey: EcPrivateKey, data: var openArray[byte]): EcResult[int] =
return err(EcKeyIncorrectError)
if seckey.key.curve in EcSupportedCurvesCint:
var offset, length: int
var pubkey = ?seckey.getPublicKey()
var pubkey = ? seckey.getPublicKey()
var b = Asn1Buffer.init()
var p = Asn1Composite.init(Asn1Tag.Sequence)
var c0 = Asn1Composite.init(0)
@@ -378,14 +387,16 @@ proc toBytes*(seckey: EcPrivateKey, data: var openArray[byte]): EcResult[int] =
if offset < 0:
return err(EcKeyIncorrectError)
length = int(pubkey.key.qlen)
c1.write(Asn1Tag.BitString, pubkey.buffer.toOpenArray(offset, offset + length - 1))
c1.write(Asn1Tag.BitString,
pubkey.buffer.toOpenArray(offset, offset + length - 1))
c1.finish()
offset = seckey.getOffset()
if offset < 0:
return err(EcKeyIncorrectError)
length = int(seckey.key.xlen)
p.write(1'u64)
p.write(Asn1Tag.OctetString, seckey.buffer.toOpenArray(offset, offset + length - 1))
p.write(Asn1Tag.OctetString,
seckey.buffer.toOpenArray(offset, offset + length - 1))
p.write(c0)
p.write(c1)
p.finish()
@@ -399,6 +410,7 @@ proc toBytes*(seckey: EcPrivateKey, data: var openArray[byte]): EcResult[int] =
else:
err(EcKeyIncorrectError)
proc toBytes*(pubkey: EcPublicKey, data: var openArray[byte]): EcResult[int] =
## Serialize EC public key ``pubkey`` to ASN.1 DER binary form and store it
## to ``data``.
@@ -424,7 +436,8 @@ proc toBytes*(pubkey: EcPublicKey, data: var openArray[byte]): EcResult[int] =
if offset < 0:
return err(EcKeyIncorrectError)
let length = int(pubkey.key.qlen)
p.write(Asn1Tag.BitString, pubkey.buffer.toOpenArray(offset, offset + length - 1))
p.write(Asn1Tag.BitString,
pubkey.buffer.toOpenArray(offset, offset + length - 1))
p.finish()
b.write(p)
b.finish()
@@ -454,9 +467,9 @@ proc getBytes*(seckey: EcPrivateKey): EcResult[seq[byte]] =
return err(EcKeyIncorrectError)
if seckey.key.curve in EcSupportedCurvesCint:
var res = newSeq[byte]()
let length = ?seckey.toBytes(res)
let length = ? seckey.toBytes(res)
res.setLen(length)
discard ?seckey.toBytes(res)
discard ? seckey.toBytes(res)
ok(res)
else:
err(EcKeyIncorrectError)
@@ -467,9 +480,9 @@ proc getBytes*(pubkey: EcPublicKey): EcResult[seq[byte]] =
return err(EcKeyIncorrectError)
if pubkey.key.curve in EcSupportedCurvesCint:
var res = newSeq[byte]()
let length = ?pubkey.toBytes(res)
let length = ? pubkey.toBytes(res)
res.setLen(length)
discard ?pubkey.toBytes(res)
discard ? pubkey.toBytes(res)
ok(res)
else:
err(EcKeyIncorrectError)
@@ -479,9 +492,9 @@ proc getBytes*(sig: EcSignature): EcResult[seq[byte]] =
if isNil(sig):
return err(EcSignatureError)
var res = newSeq[byte]()
let length = ?sig.toBytes(res)
let length = ? sig.toBytes(res)
res.setLen(length)
discard ?sig.toBytes(res)
discard ? sig.toBytes(res)
ok(res)
proc getRawBytes*(seckey: EcPrivateKey): EcResult[seq[byte]] =
@@ -490,9 +503,9 @@ proc getRawBytes*(seckey: EcPrivateKey): EcResult[seq[byte]] =
return err(EcKeyIncorrectError)
if seckey.key.curve in EcSupportedCurvesCint:
var res = newSeq[byte]()
let length = ?seckey.toRawBytes(res)
let length = ? seckey.toRawBytes(res)
res.setLen(length)
discard ?seckey.toRawBytes(res)
discard ? seckey.toRawBytes(res)
ok(res)
else:
err(EcKeyIncorrectError)
@@ -503,9 +516,9 @@ proc getRawBytes*(pubkey: EcPublicKey): EcResult[seq[byte]] =
return err(EcKeyIncorrectError)
if pubkey.key.curve in EcSupportedCurvesCint:
var res = newSeq[byte]()
let length = ?pubkey.toRawBytes(res)
let length = ? pubkey.toRawBytes(res)
res.setLen(length)
discard ?pubkey.toRawBytes(res)
discard ? pubkey.toRawBytes(res)
return ok(res)
else:
return err(EcKeyIncorrectError)
@@ -515,9 +528,9 @@ proc getRawBytes*(sig: EcSignature): EcResult[seq[byte]] =
if isNil(sig):
return err(EcSignatureError)
var res = newSeq[byte]()
let length = ?sig.toBytes(res)
let length = ? sig.toBytes(res)
res.setLen(length)
discard ?sig.toBytes(res)
discard ? sig.toBytes(res)
ok(res)
proc `==`*(pubkey1, pubkey2: EcPublicKey): bool =
@@ -537,10 +550,8 @@ proc `==`*(pubkey1, pubkey2: EcPublicKey): bool =
let op2 = pubkey2.getOffset()
if op1 == -1 or op2 == -1:
return false
return CT.isEqual(
pubkey1.buffer.toOpenArray(op1, pubkey1.key.qlen - 1),
pubkey2.buffer.toOpenArray(op2, pubkey2.key.qlen - 1),
)
return CT.isEqual(pubkey1.buffer.toOpenArray(op1, pubkey1.key.qlen - 1),
pubkey2.buffer.toOpenArray(op2, pubkey2.key.qlen - 1))
proc `==`*(seckey1, seckey2: EcPrivateKey): bool =
## Returns ``true`` if both keys ``seckey1`` and ``seckey2`` are equal.
@@ -559,10 +570,8 @@ proc `==`*(seckey1, seckey2: EcPrivateKey): bool =
let op2 = seckey2.getOffset()
if op1 == -1 or op2 == -1:
return false
return CT.isEqual(
seckey1.buffer.toOpenArray(op1, seckey1.key.xlen - 1),
seckey2.buffer.toOpenArray(op2, seckey2.key.xlen - 1),
)
return CT.isEqual(seckey1.buffer.toOpenArray(op1, seckey1.key.xlen - 1),
seckey2.buffer.toOpenArray(op2, seckey2.key.xlen - 1))
proc `==`*(a, b: EcSignature): bool =
## Return ``true`` if both signatures ``sig1`` and ``sig2`` are equal.
@@ -596,36 +605,36 @@ proc init*(key: var EcPrivateKey, data: openArray[byte]): Result[void, Asn1Error
var ab = Asn1Buffer.init(data)
field = ?ab.read()
field = ? ab.read()
if field.kind != Asn1Tag.Sequence:
return err(Asn1Error.Incorrect)
var ib = field.getBuffer()
field = ?ib.read()
field = ? ib.read()
if field.kind != Asn1Tag.Integer:
return err(Asn1Error.Incorrect)
if field.vint != 1'u64:
return err(Asn1Error.Incorrect)
raw = ?ib.read()
raw = ? ib.read()
if raw.kind != Asn1Tag.OctetString:
return err(Asn1Error.Incorrect)
oid = ?ib.read()
oid = ? ib.read()
if oid.kind != Asn1Tag.Oid:
return err(Asn1Error.Incorrect)
if oid == Asn1OidSecp256r1:
curve = safeConvert[cint](Secp256r1)
curve = cast[cint](Secp256r1)
elif oid == Asn1OidSecp384r1:
curve = safeConvert[cint](Secp384r1)
curve = cast[cint](Secp384r1)
elif oid == Asn1OidSecp521r1:
curve = safeConvert[cint](Secp521r1)
curve = cast[cint](Secp521r1)
else:
return err(Asn1Error.Incorrect)
@@ -649,19 +658,19 @@ proc init*(pubkey: var EcPublicKey, data: openArray[byte]): Result[void, Asn1Err
var ab = Asn1Buffer.init(data)
field = ?ab.read()
field = ? ab.read()
if field.kind != Asn1Tag.Sequence:
return err(Asn1Error.Incorrect)
var ib = field.getBuffer()
field = ?ib.read()
field = ? ib.read()
if field.kind != Asn1Tag.Sequence:
return err(Asn1Error.Incorrect)
var ob = field.getBuffer()
oid = ?ob.read()
oid = ? ob.read()
if oid.kind != Asn1Tag.Oid:
return err(Asn1Error.Incorrect)
@@ -669,21 +678,21 @@ proc init*(pubkey: var EcPublicKey, data: openArray[byte]): Result[void, Asn1Err
if oid != Asn1OidEcPublicKey:
return err(Asn1Error.Incorrect)
oid = ?ob.read()
oid = ? ob.read()
if oid.kind != Asn1Tag.Oid:
return err(Asn1Error.Incorrect)
if oid == Asn1OidSecp256r1:
curve = safeConvert[cint](Secp256r1)
curve = cast[cint](Secp256r1)
elif oid == Asn1OidSecp384r1:
curve = safeConvert[cint](Secp384r1)
curve = cast[cint](Secp384r1)
elif oid == Asn1OidSecp521r1:
curve = safeConvert[cint](Secp521r1)
curve = cast[cint](Secp521r1)
else:
return err(Asn1Error.Incorrect)
raw = ?ib.read()
raw = ? ib.read()
if raw.kind != Asn1Tag.BitString:
return err(Asn1Error.Incorrect)
@@ -709,14 +718,16 @@ proc init*(sig: var EcSignature, data: openArray[byte]): Result[void, Asn1Error]
else:
err(Asn1Error.Incorrect)
proc init*[T: EcPKI](sospk: var T, data: string): Result[void, Asn1Error] {.inline.} =
proc init*[T: EcPKI](sospk: var T,
data: string): Result[void, Asn1Error] {.inline.} =
## Initialize EC `private key`, `public key` or `signature` ``sospk`` from
## ASN.1 DER hexadecimal string representation ``data``.
##
## Procedure returns ``Asn1Status``.
sospk.init(ncrutils.fromHex(data))
proc init*(t: typedesc[EcPrivateKey], data: openArray[byte]): EcResult[EcPrivateKey] =
proc init*(t: typedesc[EcPrivateKey],
data: openArray[byte]): EcResult[EcPrivateKey] =
## Initialize EC private key from ASN.1 DER binary representation ``data`` and
## return constructed object.
var key: EcPrivateKey
@@ -726,7 +737,8 @@ proc init*(t: typedesc[EcPrivateKey], data: openArray[byte]): EcResult[EcPrivate
else:
ok(key)
proc init*(t: typedesc[EcPublicKey], data: openArray[byte]): EcResult[EcPublicKey] =
proc init*(t: typedesc[EcPublicKey],
data: openArray[byte]): EcResult[EcPublicKey] =
## Initialize EC public key from ASN.1 DER binary representation ``data`` and
## return constructed object.
var key: EcPublicKey
@@ -736,7 +748,8 @@ proc init*(t: typedesc[EcPublicKey], data: openArray[byte]): EcResult[EcPublicKe
else:
ok(key)
proc init*(t: typedesc[EcSignature], data: openArray[byte]): EcResult[EcSignature] =
proc init*(t: typedesc[EcSignature],
data: openArray[byte]): EcResult[EcSignature] =
## Initialize EC signature from raw binary representation ``data`` and
## return constructed object.
var sig: EcSignature
@@ -761,13 +774,13 @@ proc initRaw*(key: var EcPrivateKey, data: openArray[byte]): bool =
## Procedure returns ``true`` on success, ``false`` otherwise.
var curve: cint
if len(data) == SecKey256Length:
curve = safeConvert[cint](Secp256r1)
curve = cast[cint](Secp256r1)
result = true
elif len(data) == SecKey384Length:
curve = safeConvert[cint](Secp384r1)
curve = cast[cint](Secp384r1)
result = true
elif len(data) == SecKey521Length:
curve = safeConvert[cint](Secp521r1)
curve = cast[cint](Secp521r1)
result = true
if result:
result = false
@@ -792,13 +805,13 @@ proc initRaw*(pubkey: var EcPublicKey, data: openArray[byte]): bool =
if len(data) > 0:
if data[0] == 0x04'u8:
if len(data) == PubKey256Length:
curve = safeConvert[cint](Secp256r1)
curve = cast[cint](Secp256r1)
result = true
elif len(data) == PubKey384Length:
curve = safeConvert[cint](Secp384r1)
curve = cast[cint](Secp384r1)
result = true
elif len(data) == PubKey521Length:
curve = safeConvert[cint](Secp521r1)
curve = cast[cint](Secp521r1)
result = true
if result:
result = false
@@ -819,7 +832,8 @@ proc initRaw*(sig: var EcSignature, data: openArray[byte]): bool =
##
## Procedure returns ``true`` on success, ``false`` otherwise.
let length = len(data)
if (length == Sig256Length) or (length == Sig384Length) or (length == Sig521Length):
if (length == Sig256Length) or (length == Sig384Length) or
(length == Sig521Length):
result = true
if result:
sig = new EcSignature
@@ -832,9 +846,8 @@ proc initRaw*[T: EcPKI](sospk: var T, data: string): bool {.inline.} =
## Procedure returns ``true`` on success, ``false`` otherwise.
result = sospk.initRaw(ncrutils.fromHex(data))
proc initRaw*(
t: typedesc[EcPrivateKey], data: openArray[byte]
): EcResult[EcPrivateKey] =
proc initRaw*(t: typedesc[EcPrivateKey],
data: openArray[byte]): EcResult[EcPrivateKey] =
## Initialize EC private key from raw binary representation ``data`` and
## return constructed object.
var res: EcPrivateKey
@@ -843,7 +856,8 @@ proc initRaw*(
else:
ok(res)
proc initRaw*(t: typedesc[EcPublicKey], data: openArray[byte]): EcResult[EcPublicKey] =
proc initRaw*(t: typedesc[EcPublicKey],
data: openArray[byte]): EcResult[EcPublicKey] =
## Initialize EC public key from raw binary representation ``data`` and
## return constructed object.
var res: EcPublicKey
@@ -852,7 +866,8 @@ proc initRaw*(t: typedesc[EcPublicKey], data: openArray[byte]): EcResult[EcPubli
else:
ok(res)
proc initRaw*(t: typedesc[EcSignature], data: openArray[byte]): EcResult[EcSignature] =
proc initRaw*(t: typedesc[EcSignature],
data: openArray[byte]): EcResult[EcSignature] =
## Initialize EC signature from raw binary representation ``data`` and
## return constructed object.
var res: EcSignature
@@ -879,19 +894,16 @@ proc scalarMul*(pub: EcPublicKey, sec: EcPrivateKey): EcPublicKey =
let poffset = key.getOffset()
let soffset = sec.getOffset()
if poffset >= 0 and soffset >= 0:
let res = impl.mul(
addr key.buffer[poffset],
key.key.qlen,
unsafeAddr sec.buffer[soffset],
sec.key.xlen,
key.key.curve,
)
let res = impl.mul(addr key.buffer[poffset],
key.key.qlen,
unsafeAddr sec.buffer[soffset],
sec.key.xlen,
key.key.curve)
if res != 0:
result = key
proc toSecret*(
pubkey: EcPublicKey, seckey: EcPrivateKey, data: var openArray[byte]
): int =
proc toSecret*(pubkey: EcPublicKey, seckey: EcPrivateKey,
data: var openArray[byte]): int =
## Calculate ECDHE shared secret using Go's elliptic/curve approach, using
## remote public key ``pubkey`` and local private key ``seckey`` and store
## shared secret to ``data``.
@@ -927,9 +939,8 @@ proc getSecret*(pubkey: EcPublicKey, seckey: EcPrivateKey): seq[byte] =
result = newSeq[byte](res)
copyMem(addr result[0], addr data[0], res)
proc sign*[T: byte | char](
seckey: EcPrivateKey, message: openArray[T]
): EcResult[EcSignature] {.gcsafe.} =
proc sign*[T: byte|char](seckey: EcPrivateKey,
message: openArray[T]): EcResult[EcSignature] {.gcsafe.} =
## Get ECDSA signature of data ``message`` using private key ``seckey``.
if isNil(seckey):
return err(EcKeyIncorrectError)
@@ -946,8 +957,8 @@ proc sign*[T: byte | char](
else:
kv.update(addr hc.vtable, nil, 0)
kv.out(addr hc.vtable, addr hash[0])
let res =
ecdsaI31SignAsn1(impl, kv, addr hash[0], addr seckey.key, addr sig.buffer[0])
let res = ecdsaI31SignAsn1(impl, kv, addr hash[0], addr seckey.key,
addr sig.buffer[0])
# Clear context with initial value
kv.init(addr hc.vtable)
if res != 0:
@@ -958,9 +969,8 @@ proc sign*[T: byte | char](
else:
err(EcKeyIncorrectError)
proc verify*[T: byte | char](
sig: EcSignature, message: openArray[T], pubkey: EcPublicKey
): bool {.inline.} =
proc verify*[T: byte|char](sig: EcSignature, message: openArray[T],
pubkey: EcPublicKey): bool {.inline.} =
## Verify ECDSA signature ``sig`` using public key ``pubkey`` and data
## ``message``.
##
@@ -978,41 +988,9 @@ proc verify*[T: byte | char](
else:
kv.update(addr hc.vtable, nil, 0)
kv.out(addr hc.vtable, addr hash[0])
let res = ecdsaI31VrfyAsn1(
impl,
addr hash[0],
uint(len(hash)),
unsafeAddr pubkey.key,
addr sig.buffer[0],
uint(len(sig.buffer)),
)
let res = ecdsaI31VrfyAsn1(impl, addr hash[0], uint(len(hash)),
unsafeAddr pubkey.key,
addr sig.buffer[0], uint(len(sig.buffer)))
# Clear context with initial value
kv.init(addr hc.vtable)
result = (res == 1)
type ECDHEScheme* = EcCurveKind
proc ephemeral*(scheme: ECDHEScheme, rng: var HmacDrbgContext): EcResult[EcKeyPair] =
## Generate ephemeral keys used to perform ECDHE.
var keypair: EcKeyPair
if scheme == Secp256r1:
keypair = ?EcKeyPair.random(Secp256r1, rng)
elif scheme == Secp384r1:
keypair = ?EcKeyPair.random(Secp384r1, rng)
elif scheme == Secp521r1:
keypair = ?EcKeyPair.random(Secp521r1, rng)
ok(keypair)
proc ephemeral*(scheme: string, rng: var HmacDrbgContext): EcResult[EcKeyPair] =
## Generate ephemeral keys used to perform ECDHE using string encoding.
##
## Currently supported encoding strings are P-256, P-384 and P-521; if the
## encoding string is not supported, a P-521 key will be generated.
if scheme == "P-256":
ephemeral(Secp256r1, rng)
elif scheme == "P-384":
ephemeral(Secp384r1, rng)
elif scheme == "P-521":
ephemeral(Secp521r1, rng)
else:
ephemeral(Secp521r1, rng)
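# ECDHE sketch over P-256 using the string-encoded helper above together
# with `getSecret`; `newRng()` is an assumed RNG helper.
block:
  let rng = newRng()
  let a = ephemeral("P-256", rng[]).expect("keypair")
  let b = ephemeral("P-256", rng[]).expect("keypair")
  doAssert getSecret(a.pubkey, b.seckey) == getSecret(b.pubkey, a.seckey)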

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -9,48 +9,28 @@
# https://tools.ietf.org/html/rfc5869
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import nimcrypto
import bearssl/[kdf, hash]
import bearssl/[kdf, rand, hash]
type HkdfResult*[len: static int] = array[len, byte]
proc hkdf*[T: sha256, len: static int](
_: type[T],
salt, ikm, info: openArray[byte],
outputs: var openArray[HkdfResult[len]],
) =
var ctx: HkdfContext
proc hkdf*[T: sha256; len: static int](_: type[T]; salt, ikm, info: openArray[byte]; outputs: var openArray[HkdfResult[len]]) =
var
ctx: HkdfContext
hkdfInit(
ctx,
addr sha256Vtable,
if salt.len > 0:
unsafeAddr salt[0]
else:
nil
,
csize_t(salt.len),
)
ctx, addr sha256Vtable,
if salt.len > 0: unsafeAddr salt[0] else: nil, csize_t(salt.len))
hkdfInject(
ctx,
if ikm.len > 0:
unsafeAddr ikm[0]
else:
nil
,
csize_t(ikm.len),
)
ctx, if ikm.len > 0: unsafeAddr ikm[0] else: nil, csize_t(ikm.len))
hkdfFlip(ctx)
for i in 0 .. outputs.high:
for i in 0..outputs.high:
discard hkdfProduce(
ctx,
if info.len > 0:
unsafeAddr info[0]
else:
nil
,
csize_t(info.len),
addr outputs[i][0],
csize_t(outputs[i].len),
)
if info.len > 0: unsafeAddr info[0]
else: nil, csize_t(info.len),
addr outputs[i][0], csize_t(outputs[i].len))
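# HKDF-SHA256 sketch: derive two 32-byte outputs from one input keying
# material; salt/ikm bytes are arbitrary demo values.
block:
  var outputs: array[2, HkdfResult[32]]
  let
    salt = [0x00'u8, 0x01]
    ikm = [0x0B'u8, 0x0B, 0x0B]
    info = newSeq[byte]()
  sha256.hkdf(salt, ikm, info, outputs)
  doAssert outputs[0] != outputs[1] # successive HKDF blocks differ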

View File

@@ -9,44 +9,47 @@
## This module implements minimal ASN.1 encoding/decoding primitives.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import stew/[endians2, results, ctops]
export results
# We use `ncrutils` for constant-time hexadecimal encoding/decoding procedures.
import nimcrypto/utils as ncrutils
import ../utility
type
Asn1Error* {.pure.} = enum
Overflow
Incomplete
Indefinite
Incorrect
NoSupport
Overflow,
Incomplete,
Indefinite,
Incorrect,
NoSupport,
Overrun
Asn1Result*[T] = Result[T, Asn1Error]
Asn1Class* {.pure.} = enum
Universal = 0x00
Universal = 0x00,
Application = 0x01
ContextSpecific = 0x02
Private = 0x03
Asn1Tag* {.pure.} = enum
## Protobuf's field types enum
NoSupport
Boolean
Integer
BitString
OctetString
Null
Oid
Sequence
NoSupport,
Boolean,
Integer,
BitString,
OctetString,
Null,
Oid,
Sequence,
Context
Asn1Buffer* = object of RootObj ## ASN.1's message representation object
Asn1Buffer* = object of RootObj
## ASN.1's message representation object
buffer*: seq[byte]
offset*: int
length*: int
@@ -72,23 +75,37 @@ type
idx*: int
const
Asn1OidSecp256r1* =
[0x2A'u8, 0x86'u8, 0x48'u8, 0xCE'u8, 0x3D'u8, 0x03'u8, 0x01'u8, 0x07'u8]
Asn1OidSecp256r1* = [
0x2A'u8, 0x86'u8, 0x48'u8, 0xCE'u8, 0x3D'u8, 0x03'u8, 0x01'u8, 0x07'u8
]
## Encoded OID for `secp256r1` curve (1.2.840.10045.3.1.7)
Asn1OidSecp384r1* = [0x2B'u8, 0x81'u8, 0x04'u8, 0x00'u8, 0x22'u8]
Asn1OidSecp384r1* = [
0x2B'u8, 0x81'u8, 0x04'u8, 0x00'u8, 0x22'u8
]
## Encoded OID for `secp384r1` curve (1.3.132.0.34)
Asn1OidSecp521r1* = [0x2B'u8, 0x81'u8, 0x04'u8, 0x00'u8, 0x23'u8]
Asn1OidSecp521r1* = [
0x2B'u8, 0x81'u8, 0x04'u8, 0x00'u8, 0x23'u8
]
## Encoded OID for `secp521r1` curve (1.3.132.0.35)
Asn1OidSecp256k1* = [0x2B'u8, 0x81'u8, 0x04'u8, 0x00'u8, 0x0A'u8]
Asn1OidSecp256k1* = [
0x2B'u8, 0x81'u8, 0x04'u8, 0x00'u8, 0x0A'u8
]
## Encoded OID for `secp256k1` curve (1.3.132.0.10)
Asn1OidEcPublicKey* = [0x2A'u8, 0x86'u8, 0x48'u8, 0xCE'u8, 0x3D'u8, 0x02'u8, 0x01'u8]
Asn1OidEcPublicKey* = [
0x2A'u8, 0x86'u8, 0x48'u8, 0xCE'u8, 0x3D'u8, 0x02'u8, 0x01'u8
]
## Encoded OID for Elliptic Curve Public Key (1.2.840.10045.2.1)
Asn1OidRsaEncryption* =
[0x2A'u8, 0x86'u8, 0x48'u8, 0x86'u8, 0xF7'u8, 0x0D'u8, 0x01'u8, 0x01'u8, 0x01'u8]
Asn1OidRsaEncryption* = [
0x2A'u8, 0x86'u8, 0x48'u8, 0x86'u8, 0xF7'u8, 0x0D'u8, 0x01'u8,
0x01'u8, 0x01'u8
]
## Encoded OID for RSA Encryption (1.2.840.113549.1.1.1)
Asn1True* = [0x01'u8, 0x01'u8, 0xFF'u8] ## Encoded boolean ``TRUE``.
Asn1False* = [0x01'u8, 0x01'u8, 0x00'u8] ## Encoded boolean ``FALSE``.
Asn1Null* = [0x05'u8, 0x00'u8] ## Encoded ``NULL`` value.
Asn1True* = [0x01'u8, 0x01'u8, 0xFF'u8]
## Encoded boolean ``TRUE``.
Asn1False* = [0x01'u8, 0x01'u8, 0x00'u8]
## Encoded boolean ``FALSE``.
Asn1Null* = [0x05'u8, 0x00'u8]
## Encoded ``NULL`` value.
template toOpenArray*(ab: Asn1Buffer): untyped =
toOpenArray(ab.buffer, ab.offset, ab.buffer.high)
@@ -102,10 +119,10 @@ template toOpenArray*(af: Asn1Field): untyped =
template isEmpty*(ab: Asn1Buffer): bool =
ab.offset >= len(ab.buffer)
template isEnough*(ab: Asn1Buffer, length: int64): bool =
template isEnough*(ab: Asn1Buffer, length: int): bool =
len(ab.buffer) >= ab.offset + length
proc len*[T: Asn1Buffer | Asn1Composite](abc: T): int {.inline.} =
proc len*[T: Asn1Buffer|Asn1Composite](abc: T): int {.inline.} =
len(abc.buffer) - abc.offset
proc len*(field: Asn1Field): int {.inline.} =
@@ -114,22 +131,31 @@ proc len*(field: Asn1Field): int {.inline.} =
template getPtr*(field: untyped): pointer =
cast[pointer](unsafeAddr field.buffer[field.offset])
proc extend*[T: Asn1Buffer | Asn1Composite](abc: var T, length: int) {.inline.} =
proc extend*[T: Asn1Buffer|Asn1Composite](abc: var T, length: int) {.inline.} =
## Extend buffer or composite's internal buffer by ``length`` octets.
abc.buffer.setLen(len(abc.buffer) + length)
proc code*(tag: Asn1Tag): byte {.inline.} =
## Converts Nim ``tag`` enum to ASN.1 tag code.
case tag
of Asn1Tag.NoSupport: 0x00'u8
of Asn1Tag.Boolean: 0x01'u8
of Asn1Tag.Integer: 0x02'u8
of Asn1Tag.BitString: 0x03'u8
of Asn1Tag.OctetString: 0x04'u8
of Asn1Tag.Null: 0x05'u8
of Asn1Tag.Oid: 0x06'u8
of Asn1Tag.Sequence: 0x30'u8
of Asn1Tag.Context: 0xA0'u8
case tag:
of Asn1Tag.NoSupport:
0x00'u8
of Asn1Tag.Boolean:
0x01'u8
of Asn1Tag.Integer:
0x02'u8
of Asn1Tag.BitString:
0x03'u8
of Asn1Tag.OctetString:
0x04'u8
of Asn1Tag.Null:
0x05'u8
of Asn1Tag.Oid:
0x06'u8
of Asn1Tag.Sequence:
0x30'u8
of Asn1Tag.Context:
0xA0'u8
proc asn1EncodeLength*(dest: var openArray[byte], length: uint64): int =
## Encode ASN.1 DER length part of TLV triple and return number of bytes
@@ -158,7 +184,8 @@ proc asn1EncodeLength*(dest: var openArray[byte], length: uint64): int =
# then 9, so it is safe to convert it to `int`.
int(res)
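# DER length encoding sketch: short form below 128, long form (0x81 prefix)
# from 128 upward.
block:
  var buf = newSeq[byte](9)
  doAssert asn1EncodeLength(buf, 127'u64) == 1 and buf[0] == 0x7F'u8
  doAssert asn1EncodeLength(buf, 200'u64) == 2 and
    buf[0] == 0x81'u8 and buf[1] == 0xC8'u8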
proc asn1EncodeInteger*(dest: var openArray[byte], value: openArray[byte]): int =
proc asn1EncodeInteger*(dest: var openArray[byte],
value: openArray[byte]): int =
## Encode big-endian binary representation of integer as ASN.1 DER `INTEGER`
## and return number of bytes (octets) used.
##
@@ -168,16 +195,17 @@ proc asn1EncodeInteger*(dest: var openArray[byte], value: openArray[byte]): int
var buffer: array[16, byte]
var lenlen = 0
let offset = block:
var o = 0
for i in 0 ..< len(value):
if value[o] != 0x00:
break
inc(o)
if o < len(value):
o
else:
o - 1
let offset =
block:
var o = 0
for i in 0 ..< len(value):
if value[o] != 0x00:
break
inc(o)
if o < len(value):
o
else:
o - 1
let destlen =
if len(value) > 0:
@@ -199,10 +227,12 @@ proc asn1EncodeInteger*(dest: var openArray[byte], value: openArray[byte]): int
if value[offset] >= 0x80'u8:
dest[1 + lenlen] = 0x00'u8
shift = 2
copyMem(addr dest[shift + lenlen], unsafeAddr value[offset], len(value) - offset)
copyMem(addr dest[shift + lenlen], unsafeAddr value[offset],
len(value) - offset)
destlen
proc asn1EncodeInteger*[T: SomeUnsignedInt](dest: var openArray[byte], value: T): int =
proc asn1EncodeInteger*[T: SomeUnsignedInt](dest: var openArray[byte],
value: T): int =
## Encode Nim's unsigned integer as ASN.1 DER `INTEGER` and return number of
## bytes (octets) used.
##
@@ -237,7 +267,8 @@ proc asn1EncodeNull*(dest: var openArray[byte]): int =
dest[1] = 0x00'u8
res
proc asn1EncodeOctetString*(dest: var openArray[byte], value: openArray[byte]): int =
proc asn1EncodeOctetString*(dest: var openArray[byte],
value: openArray[byte]): int =
## Encode array of bytes as ASN.1 DER `OCTET STRING` and return number of
## bytes (octets) used.
##
@@ -254,9 +285,8 @@ proc asn1EncodeOctetString*(dest: var openArray[byte], value: openArray[byte]):
copyMem(addr dest[1 + lenlen], unsafeAddr value[0], len(value))
res
proc asn1EncodeBitString*(
dest: var openArray[byte], value: openArray[byte], bits = 0
): int =
proc asn1EncodeBitString*(dest: var openArray[byte],
value: openArray[byte], bits = 0): int =
## Encode array of bytes as ASN.1 DER `BIT STRING` and return number of bytes
## (octets) used.
##
@@ -277,7 +307,7 @@ proc asn1EncodeBitString*(
let bytelen = (bitlen + 7) shr 3
# Number of unused bits
let unused = (8 - (bitlen and 7)) and 7
let mask = not ((1'u8 shl unused) - 1'u8)
let mask = not((1'u8 shl unused) - 1'u8)
var lenlen = asn1EncodeLength(buffer, uint64(bytelen + 1))
let res = 1 + lenlen + 1 + len(value)
if len(dest) >= res:
@@ -291,7 +321,8 @@ proc asn1EncodeBitString*(
dest[2 + lenlen + bytelen - 1] = lastbyte and mask
res
proc asn1EncodeTag[T: SomeUnsignedInt](dest: var openArray[byte], value: T): int =
proc asn1EncodeTag[T: SomeUnsignedInt](dest: var openArray[byte],
value: T): int =
var v = value
if value <= cast[T](0x7F):
if len(dest) >= 1:
@@ -313,6 +344,32 @@ proc asn1EncodeTag[T: SomeUnsignedInt](dest: var openArray[byte], value: T): int
dest[k - 1] = dest[k - 1] and 0x7F'u8
res
proc asn1EncodeOid*(dest: var openArray[byte], value: openArray[int]): int =
## Encode array of integers ``value`` as ASN.1 DER `OBJECT IDENTIFIER` and
## return number of bytes (octets) used.
##
## If the length of ``dest`` is less than the number of bytes required to
## encode ``value``, the encoding will not be stored in ``dest``, but the
## number of bytes (octets) required will be returned.
var buffer: array[16, byte]
var res = 1
var oidlen = 1
for i in 2..<len(value):
oidlen += asn1EncodeTag(buffer, cast[uint64](value[i]))
res += asn1EncodeLength(buffer, uint64(oidlen))
res += oidlen
if len(dest) >= res:
let last = dest.high
var offset = 1
dest[0] = Asn1Tag.Oid.code()
offset += asn1EncodeLength(dest.toOpenArray(offset, last), uint64(oidlen))
dest[offset] = cast[byte](value[0] * 40 + value[1])
offset += 1
for i in 2..<len(value):
offset += asn1EncodeTag(dest.toOpenArray(offset, last),
cast[uint64](value[i]))
res
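# Two-pass OID encoding sketch: size with an empty destination first, then
# encode; the payload matches `Asn1OidEcPublicKey` above.
block:
  var empty: seq[byte]
  let oid = [1, 2, 840, 10045, 2, 1] # id-ecPublicKey
  let size = asn1EncodeOid(empty, oid)
  var der = newSeq[byte](size)
  doAssert asn1EncodeOid(der, oid) == size
  doAssert der == @[0x06'u8, 0x07, 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01]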
proc asn1EncodeOid*(dest: var openArray[byte], value: openArray[byte]): int =
## Encode array of bytes ``value`` as ASN.1 DER `OBJECT IDENTIFIER` and return
## number of bytes (octets) used.
@@ -332,7 +389,8 @@ proc asn1EncodeOid*(dest: var openArray[byte], value: openArray[byte]): int =
copyMem(addr dest[1 + lenlen], unsafeAddr value[0], len(value))
res
proc asn1EncodeSequence*(dest: var openArray[byte], value: openArray[byte]): int =
proc asn1EncodeSequence*(dest: var openArray[byte],
value: openArray[byte]): int =
## Encode ``value`` as ASN.1 DER `SEQUENCE` and return number of bytes
## (octets) used.
##
@@ -348,7 +406,8 @@ proc asn1EncodeSequence*(dest: var openArray[byte], value: openArray[byte]): int
copyMem(addr dest[1 + lenlen], unsafeAddr value[0], len(value))
res
proc asn1EncodeComposite*(dest: var openArray[byte], value: Asn1Composite): int =
proc asn1EncodeComposite*(dest: var openArray[byte],
value: Asn1Composite): int =
## Encode composite value and return number of bytes (octets) used.
##
## If length of ``dest`` is less then number of required bytes to encode
@@ -360,12 +419,12 @@ proc asn1EncodeComposite*(dest: var openArray[byte], value: Asn1Composite): int
if len(dest) >= res:
dest[0] = value.tag.code()
copyMem(addr dest[1], addr buffer[0], lenlen)
copyMem(addr dest[1 + lenlen], unsafeAddr value.buffer[0], len(value.buffer))
copyMem(addr dest[1 + lenlen], unsafeAddr value.buffer[0],
len(value.buffer))
res
proc asn1EncodeContextTag*(
dest: var openArray[byte], value: openArray[byte], tag: int
): int =
proc asn1EncodeContextTag*(dest: var openArray[byte], value: openArray[byte],
tag: int): int =
## Encode ASN.1 DER `CONTEXT SPECIFIC TAG` ``tag`` for value ``value`` and
## return number of bytes (octets) used.
##
@@ -384,29 +443,26 @@ proc asn1EncodeContextTag*(
copyMem(addr dest[1 + lenlen], unsafeAddr value[0], len(value))
res
proc getLength(ab: var Asn1Buffer): Asn1Result[int] =
proc getLength(ab: var Asn1Buffer): Asn1Result[uint64] =
## Decode length part of ASN.1 TLV triplet.
if not ab.isEmpty():
let b = ab.buffer[ab.offset]
if (b and 0x80'u8) == 0x00'u8:
let length = safeConvert[int](b)
let length = cast[uint64](b)
ab.offset += 1
return ok(length)
if b == 0x80'u8:
return err(Asn1Error.Indefinite)
if b == 0xFF'u8:
return err(Asn1Error.Incorrect)
let octets = safeConvert[int](b and 0x7F'u8)
if octets > 8:
let octets = cast[uint64](b and 0x7F'u8)
if octets > 8'u64:
return err(Asn1Error.Overflow)
if ab.isEnough(octets):
var lengthU: uint64 = 0
for i in 0 ..< octets:
lengthU = (lengthU shl 8) or safeConvert[uint64](ab.buffer[ab.offset + i + 1])
if lengthU > uint64(int64.high):
return err(Asn1Error.Overflow)
let length = int(lengthU)
ab.offset = ab.offset + octets + 1
if ab.isEnough(int(octets)):
var length: uint64 = 0
for i in 0..<int(octets):
length = (length shl 8) or cast[uint64](ab.buffer[ab.offset + i + 1])
ab.offset = ab.offset + int(octets) + 1
return ok(length)
else:
return err(Asn1Error.Incomplete)
@@ -418,8 +474,8 @@ proc getTag(ab: var Asn1Buffer, tag: var int): Asn1Result[Asn1Class] =
if not ab.isEmpty():
let
b = ab.buffer[ab.offset]
c = safeConvert[int]((b and 0xC0'u8) shr 6)
tag = safeConvert[int](b and 0x3F)
c = int((b and 0xC0'u8) shr 6)
tag = int(b and 0x3F)
ab.offset += 1
if c >= 0 and c < 4:
ok(cast[Asn1Class](c))
@@ -433,14 +489,14 @@ proc read*(ab: var Asn1Buffer): Asn1Result[Asn1Field] =
var
field: Asn1Field
tag, ttag, offset: int
length, tlength: int
length, tlength: uint64
aclass: Asn1Class
inclass: bool
inclass = false
while true:
offset = ab.offset
aclass = ?ab.getTag(tag)
aclass = ? ab.getTag(tag)
case aclass
of Asn1Class.ContextSpecific:
@@ -449,9 +505,9 @@ proc read*(ab: var Asn1Buffer): Asn1Result[Asn1Field] =
else:
inclass = true
ttag = tag
tlength = ?ab.getLength()
tlength = ? ab.getLength()
of Asn1Class.Universal:
length = ?ab.getLength()
length = ? ab.getLength()
if inclass:
if length >= tlength:
@@ -463,35 +519,31 @@ proc read*(ab: var Asn1Buffer): Asn1Result[Asn1Field] =
if length != 1:
return err(Asn1Error.Incorrect)
if not ab.isEnough(length):
if not ab.isEnough(int(length)):
return err(Asn1Error.Incomplete)
let b = ab.buffer[ab.offset]
if b != 0xFF'u8 and b != 0x00'u8:
return err(Asn1Error.Incorrect)
return err(Asn1Error.Incorrect)
field = Asn1Field(
kind: Asn1Tag.Boolean,
klass: aclass,
index: ttag,
offset: ab.offset,
length: 1,
buffer: ab.buffer,
)
field = Asn1Field(kind: Asn1Tag.Boolean, klass: aclass,
index: ttag, offset: int(ab.offset),
length: 1, buffer: ab.buffer)
field.vbool = (b == 0xFF'u8)
ab.offset += 1
return ok(field)
of Asn1Tag.Integer.code():
# INTEGER
if length == 0:
return err(Asn1Error.Incorrect)
if not ab.isEnough(length):
return err(Asn1Error.Incomplete)
if not ab.isEnough(int(length)):
return err(Asn1Error.Incomplete)
# Count number of leading zeroes
var zc = 0
while (zc < length) and (ab.buffer[ab.offset + zc] == 0x00'u8):
while (zc < int(length)) and (ab.buffer[ab.offset + zc] == 0x00'u8):
inc(zc)
if zc > 1:
@@ -499,87 +551,68 @@ proc read*(ab: var Asn1Buffer): Asn1Result[Asn1Field] =
if zc == 0:
# Negative or Positive integer
field = Asn1Field(
kind: Asn1Tag.Integer,
klass: aclass,
index: ttag,
offset: ab.offset,
length: length,
buffer: ab.buffer,
)
field = Asn1Field(kind: Asn1Tag.Integer, klass: aclass,
index: ttag, offset: int(ab.offset),
length: int(length), buffer: ab.buffer)
if (ab.buffer[ab.offset] and 0x80'u8) == 0x80'u8:
# Negative integer
if length <= 8:
# We need this transformation because our field.vint is uint64.
for i in 0 ..< 8:
if i < 8 - length:
if i < 8 - int(length):
field.vint = (field.vint shl 8) or 0xFF'u64
else:
let offset = ab.offset + i - (8 - length)
field.vint =
(field.vint shl 8) or safeConvert[uint64](ab.buffer[offset])
let offset = ab.offset + i - (8 - int(length))
field.vint = (field.vint shl 8) or uint64(ab.buffer[offset])
else:
# Positive integer
if length <= 8:
for i in 0 ..< length:
field.vint =
(field.vint shl 8) or safeConvert[uint64](ab.buffer[ab.offset + i])
ab.offset += length
for i in 0 ..< int(length):
field.vint = (field.vint shl 8) or
uint64(ab.buffer[ab.offset + i])
ab.offset += int(length)
return ok(field)
else:
if length == 1:
# Zero value integer
field = Asn1Field(
kind: Asn1Tag.Integer,
klass: aclass,
index: ttag,
offset: ab.offset,
length: length,
vint: 0'u64,
buffer: ab.buffer,
)
ab.offset += length
field = Asn1Field(kind: Asn1Tag.Integer, klass: aclass,
index: ttag, offset: int(ab.offset),
length: int(length), vint: 0'u64,
buffer: ab.buffer)
ab.offset += int(length)
return ok(field)
else:
# Positive integer with leading zero
field = Asn1Field(
kind: Asn1Tag.Integer,
klass: aclass,
index: ttag,
offset: ab.offset + 1,
length: length - 1,
buffer: ab.buffer,
)
field = Asn1Field(kind: Asn1Tag.Integer, klass: aclass,
index: ttag, offset: int(ab.offset) + 1,
length: int(length) - 1, buffer: ab.buffer)
if length <= 9:
for i in 1 ..< length:
field.vint =
(field.vint shl 8) or safeConvert[uint64](ab.buffer[ab.offset + i])
ab.offset += length
for i in 1 ..< int(length):
field.vint = (field.vint shl 8) or
uint64(ab.buffer[ab.offset + i])
ab.offset += int(length)
return ok(field)
of Asn1Tag.BitString.code():
# BIT STRING
if length == 0:
# BIT STRING must include the `unused` bits octet, so the length should
# be at least 1.
return err(Asn1Error.Incorrect)
elif length == 1:
if ab.buffer[ab.offset] != 0x00'u8:
return err(Asn1Error.Incorrect)
else:
# Zero-length BIT STRING.
field = Asn1Field(
kind: Asn1Tag.BitString,
klass: aclass,
index: ttag,
offset: ab.offset + 1,
length: 0,
ubits: 0,
buffer: ab.buffer,
)
ab.offset += length
field = Asn1Field(kind: Asn1Tag.BitString, klass: aclass,
index: ttag, offset: int(ab.offset + 1),
length: 0, ubits: 0, buffer: ab.buffer)
ab.offset += int(length)
return ok(field)
else:
if not ab.isEnough(length):
if not ab.isEnough(int(length)):
return err(Asn1Error.Incomplete)
let unused = ab.buffer[ab.offset]
@@ -587,82 +620,61 @@ proc read*(ab: var Asn1Buffer): Asn1Result[Asn1Field] =
# Number of unused bits should not be bigger than `7`.
return err(Asn1Error.Incorrect)
let mask = (1'u8 shl safeConvert[int](unused)) - 1'u8
if (ab.buffer[ab.offset + length - 1] and mask) != 0x00'u8:
let mask = (1'u8 shl int(unused)) - 1'u8
if (ab.buffer[ab.offset + int(length) - 1] and mask) != 0x00'u8:
## All unused bits should be set to `0`.
return err(Asn1Error.Incorrect)
field = Asn1Field(
kind: Asn1Tag.BitString,
klass: aclass,
index: ttag,
offset: ab.offset + 1,
length: length - 1,
ubits: safeConvert[int](unused),
buffer: ab.buffer,
)
ab.offset += length
field = Asn1Field(kind: Asn1Tag.BitString, klass: aclass,
index: ttag, offset: int(ab.offset + 1),
length: int(length - 1), ubits: int(unused),
buffer: ab.buffer)
ab.offset += int(length)
return ok(field)
of Asn1Tag.OctetString.code():
# OCTET STRING
if not ab.isEnough(length):
if not ab.isEnough(int(length)):
return err(Asn1Error.Incomplete)
field = Asn1Field(
kind: Asn1Tag.OctetString,
klass: aclass,
index: ttag,
offset: ab.offset,
length: length,
buffer: ab.buffer,
)
ab.offset += length
field = Asn1Field(kind: Asn1Tag.OctetString, klass: aclass,
index: ttag, offset: int(ab.offset),
length: int(length), buffer: ab.buffer)
ab.offset += int(length)
return ok(field)
of Asn1Tag.Null.code():
# NULL
if length != 0:
return err(Asn1Error.Incorrect)
field = Asn1Field(
kind: Asn1Tag.Null,
klass: aclass,
index: ttag,
offset: ab.offset,
length: 0,
buffer: ab.buffer,
)
ab.offset += length
field = Asn1Field(kind: Asn1Tag.Null, klass: aclass, index: ttag,
offset: int(ab.offset), length: 0, buffer: ab.buffer)
ab.offset += int(length)
return ok(field)
of Asn1Tag.Oid.code():
# OID
if not ab.isEnough(length):
if not ab.isEnough(int(length)):
return err(Asn1Error.Incomplete)
field = Asn1Field(
kind: Asn1Tag.Oid,
klass: aclass,
index: ttag,
offset: ab.offset,
length: length,
buffer: ab.buffer,
)
ab.offset += length
field = Asn1Field(kind: Asn1Tag.Oid, klass: aclass,
index: ttag, offset: int(ab.offset),
length: int(length), buffer: ab.buffer)
ab.offset += int(length)
return ok(field)
of Asn1Tag.Sequence.code():
# SEQUENCE
if not ab.isEnough(length):
if not ab.isEnough(int(length)):
return err(Asn1Error.Incomplete)
field = Asn1Field(
kind: Asn1Tag.Sequence,
klass: aclass,
index: ttag,
offset: ab.offset,
length: length,
buffer: ab.buffer,
)
ab.offset += length
field = Asn1Field(kind: Asn1Tag.Sequence, klass: aclass,
index: ttag, offset: int(ab.offset),
length: int(length), buffer: ab.buffer)
ab.offset += int(length)
return ok(field)
else:
return err(Asn1Error.NoSupport)
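
As a concrete decoding sketch for the tag dispatch above (assuming `Asn1Buffer.init` accepts raw DER bytes, as the RSA `init` procs further below rely on):

# Minimal sketch: parse a DER-encoded BOOLEAN (tag 0x01, length 1, value 0xFF).
var ab = Asn1Buffer.init([byte 0x01, 0x01, 0xFF])
let res = ab.read()
if res.isOk():
  let field = res.get()
  doAssert field.kind == Asn1Tag.Boolean and field.vbool
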
@@ -685,9 +697,9 @@ proc `==`*(field: Asn1Field, data: openArray[byte]): bool =
if length > 0:
if field.length == len(data):
CT.isEqual(
field.buffer.toOpenArray(field.offset, field.offset + field.length - 1),
data.toOpenArray(0, field.length - 1),
)
field.buffer.toOpenArray(field.offset,
field.offset + field.length - 1),
data.toOpenArray(0, field.length - 1))
else:
false
else:
@@ -765,14 +777,13 @@ proc `$`*(field: Asn1Field): string =
res.add(ncrutils.toHex(field.toOpenArray()))
res
proc write*[T: Asn1Buffer | Asn1Composite](abc: var T, tag: Asn1Tag) =
proc write*[T: Asn1Buffer|Asn1Composite](abc: var T, tag: Asn1Tag) =
## Write empty value to buffer or composite with ``tag``.
##
## This procedure must be used to write `NULL`, `0` or empty `BIT STRING`,
## `OCTET STRING` types.
doAssert(
tag in {Asn1Tag.Null, Asn1Tag.Integer, Asn1Tag.BitString, Asn1Tag.OctetString}
)
doAssert(tag in {Asn1Tag.Null, Asn1Tag.Integer, Asn1Tag.BitString,
Asn1Tag.OctetString})
var length: int
if tag == Asn1Tag.Null:
length = asn1EncodeNull(abc.toOpenArray())
@@ -794,23 +805,22 @@ proc write*[T: Asn1Buffer | Asn1Composite](abc: var T, tag: Asn1Tag) =
discard asn1EncodeOctetString(abc.toOpenArray(), tmp.toOpenArray(0, -1))
abc.offset += length
proc write*[T: Asn1Buffer | Asn1Composite](abc: var T, value: uint64) =
proc write*[T: Asn1Buffer|Asn1Composite](abc: var T, value: uint64) =
## Write uint64 ``value`` to buffer or composite as ASN.1 `INTEGER`.
let length = asn1EncodeInteger(abc.toOpenArray(), value)
abc.extend(length)
discard asn1EncodeInteger(abc.toOpenArray(), value)
abc.offset += length
proc write*[T: Asn1Buffer | Asn1Composite](abc: var T, value: bool) =
proc write*[T: Asn1Buffer|Asn1Composite](abc: var T, value: bool) =
## Write bool ``value`` to buffer or composite as ASN.1 `BOOLEAN`.
let length = asn1EncodeBoolean(abc.toOpenArray(), value)
abc.extend(length)
discard asn1EncodeBoolean(abc.toOpenArray(), value)
abc.offset += length
proc write*[T: Asn1Buffer | Asn1Composite](
abc: var T, tag: Asn1Tag, value: openArray[byte], bits = 0
) =
proc write*[T: Asn1Buffer|Asn1Composite](abc: var T, tag: Asn1Tag,
value: openArray[byte], bits = 0) =
## Write array ``value`` using ``tag``.
##
## This procedure is used to write ASN.1 `INTEGER`, `OCTET STRING`,
@@ -818,9 +828,8 @@ proc write*[T: Asn1Buffer | Asn1Composite](
##
## For `BIT STRING` you can use ``bits`` argument to specify number of used
## bits.
doAssert(
tag in {Asn1Tag.Integer, Asn1Tag.OctetString, Asn1Tag.BitString, Asn1Tag.Oid}
)
doAssert(tag in {Asn1Tag.Integer, Asn1Tag.OctetString, Asn1Tag.BitString,
Asn1Tag.Oid})
var length: int
if tag == Asn1Tag.Integer:
length = asn1EncodeInteger(abc.toOpenArray(), value)
@@ -840,7 +849,7 @@ proc write*[T: Asn1Buffer | Asn1Composite](
discard asn1EncodeOid(abc.toOpenArray(), value)
abc.offset += length
proc write*[T: Asn1Buffer | Asn1Composite](abc: var T, value: Asn1Composite) =
proc write*[T: Asn1Buffer|Asn1Composite](abc: var T, value: Asn1Composite) =
doAssert(len(value) > 0, "Composite value not finished")
var length: int
if value.tag == Asn1Tag.Sequence:
@@ -857,6 +866,6 @@ proc write*[T: Asn1Buffer | Asn1Composite](abc: var T, value: Asn1Composite) =
discard asn1EncodeContextTag(abc.toOpenArray(), value.buffer, value.idx)
abc.offset += length
proc finish*[T: Asn1Buffer | Asn1Composite](abc: var T) {.inline.} =
proc finish*[T: Asn1Buffer|Asn1Composite](abc: var T) {.inline.} =
## Finishes buffer or composite and prepares it for writing.
abc.offset = 0
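
A usage sketch for the writer API above, mirroring how the RSA serialization below composes nested structures:

# Sketch: build SEQUENCE { INTEGER 42, OCTET STRING DE AD } in DER.
var buf = Asn1Buffer.init()
var comp = Asn1Composite.init(Asn1Tag.Sequence)
comp.write(42'u64)                                  # INTEGER
comp.write(Asn1Tag.OctetString, [byte 0xDE, 0xAD])  # OCTET STRING
comp.finish()      # seal the composite
buf.write(comp)    # embed the SEQUENCE into the buffer
buf.finish()       # rewind offset; buf.buffer now holds the DER encoding
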

View File

@@ -13,7 +13,10 @@
## BearSSL library <https://bearssl.org/>
## Copyright(C) 2018 Thomas Pornin <pornin@bolet.org>.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import bearssl/[rsa, rand, hash]
import minasn1
@@ -30,17 +33,32 @@ const
MinKeySize* = 2048
## Minimal allowed RSA key size in bits.
## https://github.com/libp2p/go-libp2p-core/blob/master/crypto/rsa_common.go#L13
DefaultKeySize* = 3072 ## Default RSA key size in bits.
DefaultKeySize* = 3072
## Default RSA key size in bits.
RsaOidSha1* = [byte 0x05, 0x2B, 0x0E, 0x03, 0x02, 0x1A]
RsaOidSha1* = [
byte 0x05, 0x2B, 0x0E, 0x03, 0x02, 0x1A
]
## RSA PKCS#1.5 SHA-1 hash object identifier.
RsaOidSha224* = [byte 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04]
RsaOidSha224* = [
byte 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04,
0x02, 0x04
]
## RSA PKCS#1.5 SHA-224 hash object identifier.
RsaOidSha256* = [byte 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01]
RsaOidSha256* = [
byte 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04,
0x02, 0x01
]
## RSA PKCS#1.5 SHA-256 hash object identifier.
RsaOidSha384* = [byte 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02]
RsaOidSha384* = [
byte 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04,
0x02, 0x02
]
## RSA PKCS#1.5 SHA-384 hash object identifier.
RsaOidSha512* = [byte 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03]
RsaOidSha512* = [
byte 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04,
0x02, 0x03
]
## RSA PKCS#1.5 SHA-512 hash object identifier.
type
@@ -64,9 +82,9 @@ type
RsaKP* = RsaPrivateKey | RsaKeyPair
RsaError* = enum
RsaGenError
RsaKeyIncorrectError
RsaSignatureError
RsaGenError,
RsaKeyIncorrectError,
RsaSignatureError,
RsaLowSecurityError
RsaResult*[T] = Result[T, RsaError]
@@ -94,18 +112,15 @@ template getArray*(bs, os, ls: untyped): untyped =
template trimZeroes(b: seq[byte], pt, ptlen: untyped) =
var length = ptlen
for i in 0 ..< length:
for i in 0..<length:
if pt[] != byte(0x00):
break
pt = cast[ptr byte](cast[uint](pt) + 1)
ptlen -= 1
proc random*[T: RsaKP](
t: typedesc[T],
rng: var HmacDrbgContext,
bits = DefaultKeySize,
pubexp = DefaultPublicExponent,
): RsaResult[T] =
proc random*[T: RsaKP](t: typedesc[T], rng: var HmacDrbgContext,
bits = DefaultKeySize,
pubexp = DefaultPublicExponent): RsaResult[T] =
## Generate new random RSA private key using BearSSL's HMAC-SHA256-DRBG
## algorithm.
##
@@ -127,15 +142,10 @@ proc random*[T: RsaKP](
var keygen = rsaKeygenGetDefault()
if keygen(
addr rng.vtable,
addr res.seck,
addr res.buffer[sko],
addr res.pubk,
addr res.buffer[pko],
cuint(bits),
pubexp,
) == 0:
if keygen(addr rng.vtable,
addr res.seck, addr res.buffer[sko],
addr res.pubk, addr res.buffer[pko],
cuint(bits), pubexp) == 0:
return err(RsaGenError)
let
@@ -163,10 +173,9 @@ proc copy*[T: RsaPKI](key: T): T =
doAssert(not isNil(key))
when T is RsaPrivateKey:
if len(key.buffer) > 0:
let length =
key.seck.plen.uint + key.seck.qlen.uint + key.seck.dplen.uint +
key.seck.dqlen.uint + key.seck.iqlen.uint + key.pubk.nlen.uint +
key.pubk.elen.uint + key.pexplen.uint
let length = key.seck.plen.uint + key.seck.qlen.uint + key.seck.dplen.uint +
key.seck.dqlen.uint + key.seck.iqlen.uint + key.pubk.nlen.uint +
key.pubk.elen.uint + key.pexplen.uint
result = new RsaPrivateKey
result.buffer = newSeq[byte](length)
let po: uint = 0
@@ -229,7 +238,8 @@ proc getPublicKey*(key: RsaPrivateKey): RsaPublicKey =
result.key.n = addr result.buffer[0]
result.key.e = addr result.buffer[key.pubk.nlen]
copyMem(addr result.buffer[0], cast[pointer](key.pubk.n), key.pubk.nlen)
copyMem(addr result.buffer[key.pubk.nlen], cast[pointer](key.pubk.e), key.pubk.elen)
copyMem(addr result.buffer[key.pubk.nlen], cast[pointer](key.pubk.e),
key.pubk.elen)
result.key.nlen = key.pubk.nlen
result.key.elen = key.pubk.elen
@@ -241,7 +251,7 @@ proc pubkey*(pair: RsaKeyPair): RsaPublicKey {.inline.} =
## Get RSA public key from pair ``pair``.
result = cast[RsaPrivateKey](pair).getPublicKey()
proc clear*[T: RsaPKI | RsaKeyPair](pki: var T) =
proc clear*[T: RsaPKI|RsaKeyPair](pki: var T) =
## Wipe and clear RSA private key, public key or signature object.
doAssert(not isNil(pki))
when T is RsaPrivateKey:
@@ -285,14 +295,21 @@ proc toBytes*(key: RsaPrivateKey, data: var openArray[byte]): RsaResult[int] =
var b = Asn1Buffer.init()
var p = Asn1Composite.init(Asn1Tag.Sequence)
p.write(0'u64)
p.write(Asn1Tag.Integer, getArray(key.buffer, key.pubk.n, key.pubk.nlen))
p.write(Asn1Tag.Integer, getArray(key.buffer, key.pubk.e, key.pubk.elen))
p.write(Asn1Tag.Integer, getArray(key.buffer, key.pubk.n,
key.pubk.nlen))
p.write(Asn1Tag.Integer, getArray(key.buffer, key.pubk.e,
key.pubk.elen))
p.write(Asn1Tag.Integer, getArray(key.buffer, key.pexp, key.pexplen))
p.write(Asn1Tag.Integer, getArray(key.buffer, key.seck.p, key.seck.plen))
p.write(Asn1Tag.Integer, getArray(key.buffer, key.seck.q, key.seck.qlen))
p.write(Asn1Tag.Integer, getArray(key.buffer, key.seck.dp, key.seck.dplen))
p.write(Asn1Tag.Integer, getArray(key.buffer, key.seck.dq, key.seck.dqlen))
p.write(Asn1Tag.Integer, getArray(key.buffer, key.seck.iq, key.seck.iqlen))
p.write(Asn1Tag.Integer, getArray(key.buffer, key.seck.p,
key.seck.plen))
p.write(Asn1Tag.Integer, getArray(key.buffer, key.seck.q,
key.seck.qlen))
p.write(Asn1Tag.Integer, getArray(key.buffer, key.seck.dp,
key.seck.dplen))
p.write(Asn1Tag.Integer, getArray(key.buffer, key.seck.dq,
key.seck.dqlen))
p.write(Asn1Tag.Integer, getArray(key.buffer, key.seck.iq,
key.seck.iqlen))
p.finish()
b.write(p)
b.finish()
@@ -357,7 +374,7 @@ proc getBytes*(key: RsaPrivateKey): RsaResult[seq[byte]] =
if isNil(key):
return err(RsaKeyIncorrectError)
var res = newSeq[byte](4096)
let length = ?key.toBytes(res)
let length = ? key.toBytes(res)
if length > 0:
res.setLen(length)
ok(res)
@@ -370,7 +387,7 @@ proc getBytes*(key: RsaPublicKey): RsaResult[seq[byte]] =
if isNil(key):
return err(RsaKeyIncorrectError)
var res = newSeq[byte](4096)
let length = ?key.toBytes(res)
let length = ? key.toBytes(res)
if length > 0:
res.setLen(length)
ok(res)
@@ -382,7 +399,7 @@ proc getBytes*(sig: RsaSignature): RsaResult[seq[byte]] =
if isNil(sig):
return err(RsaSignatureError)
var res = newSeq[byte](4096)
let length = ?sig.toBytes(res)
let length = ? sig.toBytes(res)
if length > 0:
res.setLen(length)
ok(res)
@@ -394,19 +411,20 @@ proc init*(key: var RsaPrivateKey, data: openArray[byte]): Result[void, Asn1Erro
## ``data``.
##
## Procedure returns ``Asn1Status``.
var field, rawn, rawpube, rawprie, rawp, rawq, rawdp, rawdq, rawiq: Asn1Field
var
field, rawn, rawpube, rawprie, rawp, rawq, rawdp, rawdq, rawiq: Asn1Field
# Asn1Field is not trivial so avoid too much Result
var ab = Asn1Buffer.init(data)
field = ?ab.read()
field = ? ab.read()
if field.kind != Asn1Tag.Sequence:
return err(Asn1Error.Incorrect)
var ib = field.getBuffer()
field = ?ib.read()
field = ? ib.read()
if field.kind != Asn1Tag.Integer:
return err(Asn1Error.Incorrect)
@@ -414,48 +432,48 @@ proc init*(key: var RsaPrivateKey, data: openArray[byte]): Result[void, Asn1Erro
if field.vint != 0'u64:
return err(Asn1Error.Incorrect)
rawn = ?ib.read()
rawn = ? ib.read()
if rawn.kind != Asn1Tag.Integer:
return err(Asn1Error.Incorrect)
rawpube = ?ib.read()
rawpube = ? ib.read()
if rawpube.kind != Asn1Tag.Integer:
return err(Asn1Error.Incorrect)
rawprie = ?ib.read()
rawprie = ? ib.read()
if rawprie.kind != Asn1Tag.Integer:
return err(Asn1Error.Incorrect)
rawp = ?ib.read()
rawp = ? ib.read()
if rawp.kind != Asn1Tag.Integer:
return err(Asn1Error.Incorrect)
rawq = ?ib.read()
rawq = ? ib.read()
if rawq.kind != Asn1Tag.Integer:
return err(Asn1Error.Incorrect)
rawdp = ?ib.read()
rawdp = ? ib.read()
if rawdp.kind != Asn1Tag.Integer:
return err(Asn1Error.Incorrect)
rawdq = ?ib.read()
rawdq = ? ib.read()
if rawdq.kind != Asn1Tag.Integer:
return err(Asn1Error.Incorrect)
rawiq = ?ib.read()
rawiq = ? ib.read()
if rawiq.kind != Asn1Tag.Integer:
return err(Asn1Error.Incorrect)
if len(rawn) >= (MinKeySize shr 3) and len(rawp) > 0 and len(rawq) > 0 and
len(rawdp) > 0 and len(rawdq) > 0 and len(rawiq) > 0:
len(rawdp) > 0 and len(rawdq) > 0 and len(rawiq) > 0:
key = new RsaPrivateKey
key.buffer = @data
key.pubk.n = addr key.buffer[rawn.offset]
@@ -487,52 +505,52 @@ proc init*(key: var RsaPublicKey, data: openArray[byte]): Result[void, Asn1Error
var field, rawn, rawe: Asn1Field
var ab = Asn1Buffer.init(data)
field = ?ab.read()
field = ? ab.read()
if field.kind != Asn1Tag.Sequence:
return err(Asn1Error.Incorrect)
var ib = field.getBuffer()
field = ?ib.read()
field = ? ib.read()
if field.kind != Asn1Tag.Sequence:
return err(Asn1Error.Incorrect)
var ob = field.getBuffer()
field = ?ob.read()
field = ? ob.read()
if field.kind != Asn1Tag.Oid:
return err(Asn1Error.Incorrect)
elif field != Asn1OidRsaEncryption:
return err(Asn1Error.Incorrect)
field = ?ob.read()
field = ? ob.read()
if field.kind != Asn1Tag.Null:
return err(Asn1Error.Incorrect)
field = ?ib.read()
field = ? ib.read()
if field.kind != Asn1Tag.BitString:
return err(Asn1Error.Incorrect)
var vb = field.getBuffer()
field = ?vb.read()
field = ? vb.read()
if field.kind != Asn1Tag.Sequence:
return err(Asn1Error.Incorrect)
var sb = field.getBuffer()
rawn = ?sb.read()
rawn = ? sb.read()
if rawn.kind != Asn1Tag.Integer:
return err(Asn1Error.Incorrect)
rawe = ?sb.read()
rawe = ? sb.read()
if rawe.kind != Asn1Tag.Integer:
return err(Asn1Error.Incorrect)
@@ -560,16 +578,16 @@ proc init*(sig: var RsaSignature, data: openArray[byte]): Result[void, Asn1Error
else:
err(Asn1Error.Incorrect)
proc init*[T: RsaPKI](sospk: var T, data: string): Result[void, Asn1Error] {.inline.} =
proc init*[T: RsaPKI](sospk: var T,
data: string): Result[void, Asn1Error] {.inline.} =
## Initialize RSA `private key`, `public key` or `signature` ``sospk`` from
## hexadecimal string representation ``data``.
##
## Procedure returns ``Result[void, Asn1Status]``.
sospk.init(ncrutils.fromHex(data))
proc init*(
t: typedesc[RsaPrivateKey], data: openArray[byte]
): RsaResult[RsaPrivateKey] =
proc init*(t: typedesc[RsaPrivateKey],
data: openArray[byte]): RsaResult[RsaPrivateKey] =
## Initialize RSA private key from ASN.1 DER binary representation ``data``
## and return constructed object.
var res: RsaPrivateKey
@@ -578,7 +596,8 @@ proc init*(
else:
ok(res)
proc init*(t: typedesc[RsaPublicKey], data: openArray[byte]): RsaResult[RsaPublicKey] =
proc init*(t: typedesc[RsaPublicKey],
data: openArray[byte]): RsaResult[RsaPublicKey] =
## Initialize RSA public key from ASN.1 DER binary representation ``data``
## and return constructed object.
var res: RsaPublicKey
@@ -587,7 +606,8 @@ proc init*(t: typedesc[RsaPublicKey], data: openArray[byte]): RsaResult[RsaPubli
else:
ok(res)
proc init*(t: typedesc[RsaSignature], data: openArray[byte]): RsaResult[RsaSignature] =
proc init*(t: typedesc[RsaSignature],
data: openArray[byte]): RsaResult[RsaSignature] =
## Initialize RSA signature from raw binary representation ``data`` and
## return constructed object.
var res: RsaSignature
@@ -614,11 +634,14 @@ proc `$`*(key: RsaPrivateKey): string =
result.add("\nq = ")
result.add(ncrutils.toHex(getArray(key.buffer, key.seck.q, key.seck.qlen)))
result.add("\ndp = ")
result.add(ncrutils.toHex(getArray(key.buffer, key.seck.dp, key.seck.dplen)))
result.add(ncrutils.toHex(getArray(key.buffer, key.seck.dp,
key.seck.dplen)))
result.add("\ndq = ")
result.add(ncrutils.toHex(getArray(key.buffer, key.seck.dq, key.seck.dqlen)))
result.add(ncrutils.toHex(getArray(key.buffer, key.seck.dq,
key.seck.dqlen)))
result.add("\niq = ")
result.add(ncrutils.toHex(getArray(key.buffer, key.seck.iq, key.seck.iqlen)))
result.add(ncrutils.toHex(getArray(key.buffer, key.seck.iq,
key.seck.iqlen)))
result.add("\npre = ")
result.add(ncrutils.toHex(getArray(key.buffer, key.pexp, key.pexplen)))
result.add("\nm = ")
@@ -663,38 +686,23 @@ proc `==`*(a, b: RsaPrivateKey): bool =
false
else:
if a.seck.nBitlen == b.seck.nBitlen:
if a.seck.nBitlen > 0'u:
let r1 = CT.isEqual(
getArray(a.buffer, a.seck.p, a.seck.plen),
getArray(b.buffer, b.seck.p, b.seck.plen),
)
let r2 = CT.isEqual(
getArray(a.buffer, a.seck.q, a.seck.qlen),
getArray(b.buffer, b.seck.q, b.seck.qlen),
)
let r3 = CT.isEqual(
getArray(a.buffer, a.seck.dp, a.seck.dplen),
getArray(b.buffer, b.seck.dp, b.seck.dplen),
)
let r4 = CT.isEqual(
getArray(a.buffer, a.seck.dq, a.seck.dqlen),
getArray(b.buffer, b.seck.dq, b.seck.dqlen),
)
let r5 = CT.isEqual(
getArray(a.buffer, a.seck.iq, a.seck.iqlen),
getArray(b.buffer, b.seck.iq, b.seck.iqlen),
)
let r6 = CT.isEqual(
getArray(a.buffer, a.pexp, a.pexplen), getArray(b.buffer, b.pexp, b.pexplen)
)
let r7 = CT.isEqual(
getArray(a.buffer, a.pubk.n, a.pubk.nlen),
getArray(b.buffer, b.pubk.n, b.pubk.nlen),
)
let r8 = CT.isEqual(
getArray(a.buffer, a.pubk.e, a.pubk.elen),
getArray(b.buffer, b.pubk.e, b.pubk.elen),
)
if cast[int](a.seck.nBitlen) > 0:
let r1 = CT.isEqual(getArray(a.buffer, a.seck.p, a.seck.plen),
getArray(b.buffer, b.seck.p, b.seck.plen))
let r2 = CT.isEqual(getArray(a.buffer, a.seck.q, a.seck.qlen),
getArray(b.buffer, b.seck.q, b.seck.qlen))
let r3 = CT.isEqual(getArray(a.buffer, a.seck.dp, a.seck.dplen),
getArray(b.buffer, b.seck.dp, b.seck.dplen))
let r4 = CT.isEqual(getArray(a.buffer, a.seck.dq, a.seck.dqlen),
getArray(b.buffer, b.seck.dq, b.seck.dqlen))
let r5 = CT.isEqual(getArray(a.buffer, a.seck.iq, a.seck.iqlen),
getArray(b.buffer, b.seck.iq, b.seck.iqlen))
let r6 = CT.isEqual(getArray(a.buffer, a.pexp, a.pexplen),
getArray(b.buffer, b.pexp, b.pexplen))
let r7 = CT.isEqual(getArray(a.buffer, a.pubk.n, a.pubk.nlen),
getArray(b.buffer, b.pubk.n, b.pubk.nlen))
let r8 = CT.isEqual(getArray(a.buffer, a.pubk.e, a.pubk.elen),
getArray(b.buffer, b.pubk.e, b.pubk.elen))
r1 and r2 and r3 and r4 and r5 and r6 and r7 and r8
else:
true
@@ -732,17 +740,14 @@ proc `==`*(a, b: RsaPublicKey): bool =
elif isNil(b) and (not isNil(a)):
false
else:
let r1 = CT.isEqual(
getArray(a.buffer, a.key.n, a.key.nlen), getArray(b.buffer, b.key.n, b.key.nlen)
)
let r2 = CT.isEqual(
getArray(a.buffer, a.key.e, a.key.elen), getArray(b.buffer, b.key.e, b.key.elen)
)
let r1 = CT.isEqual(getArray(a.buffer, a.key.n, a.key.nlen),
getArray(b.buffer, b.key.n, b.key.nlen))
let r2 = CT.isEqual(getArray(a.buffer, a.key.e, a.key.elen),
getArray(b.buffer, b.key.e, b.key.elen))
(r1 and r2)
proc sign*[T: byte | char](
key: RsaPrivateKey, message: openArray[T]
): RsaResult[RsaSignature] {.gcsafe.} =
proc sign*[T: byte|char](key: RsaPrivateKey,
message: openArray[T]): RsaResult[RsaSignature] {.gcsafe.} =
## Get RSA PKCS1.5 signature of data ``message`` using SHA256 and private
## key ``key``.
if isNil(key):
@@ -761,16 +766,16 @@ proc sign*[T: byte | char](
kv.update(addr hc.vtable, nil, 0)
kv.out(addr hc.vtable, addr hash[0])
var oid = RsaOidSha256
let implRes =
impl(addr oid[0], addr hash[0], uint(len(hash)), addr key.seck, addr res.buffer[0])
let implRes = impl(addr oid[0],
addr hash[0], uint(len(hash)),
addr key.seck, addr res.buffer[0])
if implRes == 0:
err(RsaSignatureError)
else:
ok(res)
proc verify*[T: byte | char](
sig: RsaSignature, message: openArray[T], pubkey: RsaPublicKey
): bool {.inline.} =
proc verify*[T: byte|char](sig: RsaSignature, message: openArray[T],
pubkey: RsaPublicKey): bool {.inline.} =
## Verify RSA signature ``sig`` using public key ``pubkey`` and data
## ``message``.
##
@@ -790,13 +795,8 @@ proc verify*[T: byte | char](
kv.update(addr hc.vtable, nil, 0)
kv.out(addr hc.vtable, addr hash[0])
var oid = RsaOidSha256
let res = impl(
addr sig.buffer[0],
uint(len(sig.buffer)),
addr oid[0],
uint(len(check)),
addr pubkey.key,
addr check[0],
)
let res = impl(addr sig.buffer[0], uint(len(sig.buffer)),
addr oid[0],
uint(len(check)), addr pubkey.key, addr check[0])
if res == 1:
result = equalMem(addr check[0], addr hash[0], len(hash))
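
An end-to-end sketch of the signing API above; `HmacDrbgContext.new()` is assumed to be the nim-bearssl rng constructor used elsewhere in this library:

import bearssl/rand

var rng = HmacDrbgContext.new()   # assumed nim-bearssl constructor
let key = RsaPrivateKey.random(rng[], bits = DefaultKeySize).expect("keygen")
let sig = key.sign("hello world").expect("sign")
doAssert sig.verify("hello world", key.getPublicKey())
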

View File

@@ -7,18 +7,26 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import bearssl/rand
import secp256k1, stew/[byteutils, results], nimcrypto/[hash, sha2]
import
secp256k1,
stew/[byteutils, results],
nimcrypto/[hash, sha2]
export sha2, results, rand
const
SkRawPrivateKeySize* = 256 div 8 ## Size of private key in octets (bytes)
SkRawPrivateKeySize* = 256 div 8
## Size of private key in octets (bytes)
SkRawSignatureSize* = SkRawPrivateKeySize * 2 + 1
## Size of signature in octets (bytes)
SkRawPublicKeySize* = SkRawPrivateKeySize + 1 ## Size of public key in octets (bytes)
SkRawPublicKeySize* = SkRawPrivateKeySize + 1
## Size of public key in octets (bytes)
# This is extremely confusing, but it's to avoid confusion between the Eth standard and the Secp standard
type
@@ -27,6 +35,9 @@ type
SkSignature* = distinct secp256k1.SkSignature
SkKeyPair* = distinct secp256k1.SkKeyPair
template pubkey*(v: SkKeyPair): SkPublicKey = SkPublicKey(secp256k1.SkKeyPair(v).pubkey)
template seckey*(v: SkKeyPair): SkPrivateKey = SkPrivateKey(secp256k1.SkKeyPair(v).seckey)
proc random*(t: typedesc[SkPrivateKey], rng: var HmacDrbgContext): SkPrivateKey =
#TODO is there a better way?
var rngPtr = addr rng
@@ -51,31 +62,31 @@ template pubkey*(v: SkKeyPair): SkPublicKey =
proc init*(key: var SkPrivateKey, data: openArray[byte]): SkResult[void] =
## Initialize Secp256k1 `private key` ``key`` from raw binary
## representation ``data``.
key = SkPrivateKey(?secp256k1.SkSecretKey.fromRaw(data))
key = SkPrivateKey(? secp256k1.SkSecretKey.fromRaw(data))
ok()
proc init*(key: var SkPrivateKey, data: string): SkResult[void] =
## Initialize Secp256k1 `private key` ``key`` from hexadecimal string
## representation ``data``.
key = SkPrivateKey(?secp256k1.SkSecretKey.fromHex(data))
key = SkPrivateKey(? secp256k1.SkSecretKey.fromHex(data))
ok()
proc init*(key: var SkPublicKey, data: openArray[byte]): SkResult[void] =
## Initialize Secp256k1 `public key` ``key`` from raw binary
## representation ``data``.
key = SkPublicKey(?secp256k1.SkPublicKey.fromRaw(data))
key = SkPublicKey(? secp256k1.SkPublicKey.fromRaw(data))
ok()
proc init*(key: var SkPublicKey, data: string): SkResult[void] =
## Initialize Secp256k1 `public key` ``key`` from hexadecimal string
## representation ``data``.
key = SkPublicKey(?secp256k1.SkPublicKey.fromHex(data))
key = SkPublicKey(? secp256k1.SkPublicKey.fromHex(data))
ok()
proc init*(sig: var SkSignature, data: openArray[byte]): SkResult[void] =
## Initialize Secp256k1 `signature` ``sig`` from raw binary
## representation ``data``.
sig = SkSignature(?secp256k1.SkSignature.fromDer(data))
sig = SkSignature(? secp256k1.SkSignature.fromDer(data))
ok()
proc init*(sig: var SkSignature, data: string): SkResult[void] =
@@ -146,7 +157,7 @@ proc toBytes*(key: SkPrivateKey, data: var openArray[byte]): SkResult[int] =
## Procedure returns number of bytes (octets) needed to store
## Secp256k1 private key.
if len(data) >= SkRawPrivateKeySize:
data[0 ..< SkRawPrivateKeySize] = SkSecretKey(key).toRaw()
data[0..<SkRawPrivateKeySize] = SkSecretKey(key).toRaw()
ok(SkRawPrivateKeySize)
else:
err("secp: Not enough bytes")
@@ -158,7 +169,7 @@ proc toBytes*(key: SkPublicKey, data: var openArray[byte]): SkResult[int] =
## Procedure returns number of bytes (octets) needed to store
## Secp256k1 public key.
if len(data) >= SkRawPublicKeySize:
data[0 ..< SkRawPublicKeySize] = secp256k1.SkPublicKey(key).toRawCompressed()
data[0..<SkRawPublicKeySize] = secp256k1.SkPublicKey(key).toRawCompressed()
ok(SkRawPublicKeySize)
else:
err("secp: Not enough bytes")
@@ -185,28 +196,22 @@ proc getBytes*(sig: SkSignature): seq[byte] {.inline.} =
let length = toBytes(sig, result)
result.setLen(length)
proc sign*[T: byte | char](key: SkPrivateKey, msg: openArray[T]): SkSignature =
proc sign*[T: byte|char](key: SkPrivateKey, msg: openArray[T]): SkSignature =
## Sign message `msg` using private key `key` and return signature object.
let h = sha256.digest(msg)
SkSignature(sign(SkSecretKey(key), SkMessage(h.data)))
proc verify*[T: byte | char](
sig: SkSignature, msg: openArray[T], key: SkPublicKey
): bool =
proc verify*[T: byte|char](sig: SkSignature, msg: openArray[T],
key: SkPublicKey): bool =
let h = sha256.digest(msg)
verify(secp256k1.SkSignature(sig), SkMessage(h.data), secp256k1.SkPublicKey(key))
func clear*(key: var SkPrivateKey) =
clear(secp256k1.SkSecretKey(key))
func clear*(key: var SkPrivateKey) = clear(secp256k1.SkSecretKey(key))
func `$`*(key: SkPrivateKey): string =
$secp256k1.SkSecretKey(key)
func `$`*(key: SkPublicKey): string =
$secp256k1.SkPublicKey(key)
func `$`*(key: SkSignature): string =
$secp256k1.SkSignature(key)
func `$`*(key: SkKeyPair): string =
$secp256k1.SkKeyPair(key)
func `$`*(key: SkPrivateKey): string = $secp256k1.SkSecretKey(key)
func `$`*(key: SkPublicKey): string = $secp256k1.SkPublicKey(key)
func `$`*(key: SkSignature): string = $secp256k1.SkSignature(key)
func `$`*(key: SkKeyPair): string = $secp256k1.SkKeyPair(key)
func `==`*(a, b: SkPrivateKey): bool =
secp256k1.SkSecretKey(a) == secp256k1.SkSecretKey(b)
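
A matching sketch for the secp256k1 wrappers (reusing `rng` from the RSA example above; `getPublicKey` on `SkPrivateKey` is assumed, it is not shown in this hunk):

let skey = SkPrivateKey.random(rng[])   # rng: ref HmacDrbgContext, as above
let ssig = skey.sign("message")
doAssert ssig.verify("message", skey.getPublicKey())  # getPublicKey assumed
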

File diff suppressed because it is too large

View File

@@ -7,29 +7,31 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
## This module implements Pool of StreamTransport.
import chronos
const DefaultPoolSize* = 8 ## Default pool size
const
DefaultPoolSize* = 8
## Default pool size
type
ConnectionFlags = enum
None
Busy
None, Busy
PoolItem = object
transp*: StreamTransport
flags*: set[ConnectionFlags]
PoolState = enum
Connecting
Connected
Closing
Closed
Connecting, Connected, Closing, Closed
TransportPool* = ref object ## Transports pool object
TransportPool* = ref object
## Transports pool object
transports: seq[PoolItem]
busyCount: int
state: PoolState
@@ -46,16 +48,13 @@ proc waitAll[T](futs: seq[Future[T]]): Future[void] =
dec(counter)
if counter == 0:
retFuture.complete()
for fut in futs:
fut.addCallback(cb)
return retFuture
proc newPool*(
address: TransportAddress,
poolsize: int = DefaultPoolSize,
bufferSize = DefaultStreamBufferSize,
): Future[TransportPool] {.async.} =
proc newPool*(address: TransportAddress, poolsize: int = DefaultPoolSize,
bufferSize = DefaultStreamBufferSize,
): Future[TransportPool] {.async.} =
## Establish pool of connections to address ``address`` with size
## ``poolsize``.
var pool = new TransportPool
@@ -63,12 +62,12 @@ proc newPool*(
pool.transports = newSeq[PoolItem](poolsize)
var conns = newSeq[Future[StreamTransport]](poolsize)
pool.state = Connecting
for i in 0 ..< poolsize:
for i in 0..<poolsize:
conns[i] = connect(address, bufferSize)
# Waiting for all connections to be established.
await waitAll(conns)
# Checking connections and preparing pool.
for i in 0 ..< poolsize:
for i in 0..<poolsize:
if conns[i].failed:
raise conns[i].error
else:
@@ -138,7 +137,7 @@ proc close*(pool: TransportPool) {.async.} =
await pool.join()
# Closing all transports
var pending = newSeq[Future[void]](len(pool.transports))
for i in 0 ..< len(pool.transports):
for i in 0..<len(pool.transports):
let transp = pool.transports[i].transp
transp.close()
pending[i] = transp.join()
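
A usage sketch for the pool; acquisition and release of individual transports is not shown in this hunk, so that part is left as a comment:

import chronos

proc demo() {.async.} =
  let pool = await newPool(initTAddress("127.0.0.1:30333"), poolsize = 4)
  # ... acquire a transport, exchange data, release it (API not shown here) ...
  await pool.close()

waitFor demo()
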

View File

@@ -25,40 +25,34 @@
## 5. LocalAddress: optional bytes
## 6. RemoteAddress: optional bytes
## 7. Message: required bytes
import os
import os, options
import nimcrypto/utils, stew/endians2
import
protobuf/minprotobuf,
stream/connection,
protocols/secure/secure,
multiaddress,
peerid,
varint,
muxers/mplex/coder
import protobuf/minprotobuf, stream/connection, protocols/secure/secure,
multiaddress, peerid, varint, muxers/mplex/coder
from times import
getTime, toUnix, fromUnix, nanosecond, format, Time, NanosecondRange, initTime
from times import getTime, toUnix, fromUnix, nanosecond, format, Time,
NanosecondRange, initTime
from strutils import toHex, repeat
export peerid, multiaddress
export peerid, options, multiaddress
type
FlowDirection* = enum
Outgoing
Incoming
Outgoing, Incoming
ProtoMessage* = object
timestamp*: uint64
direction*: FlowDirection
message*: seq[byte]
seqID*: Opt[uint64]
mtype*: Opt[uint64]
local*: Opt[MultiAddress]
remote*: Opt[MultiAddress]
seqID*: Option[uint64]
mtype*: Option[uint64]
local*: Option[MultiAddress]
remote*: Option[MultiAddress]
const libp2p_dump_dir* {.strdefine.} = "nim-libp2p"
## default directory where all the dumps will be stored; if the path is
## relative it will be created in the home directory. You can override this path
## using ``-d:libp2p_dump_dir=<otherpath>``.
const
libp2p_dump_dir* {.strdefine.} = "nim-libp2p"
## default directory where all the dumps will be stored; if the path is
## relative it will be created in the home directory. You can override this path
## using ``-d:libp2p_dump_dir=<otherpath>``.
proc getTimestamp(): uint64 =
## This procedure is present because `stdlib.times` missing it.
@@ -71,14 +65,14 @@ proc getTimedate(value: uint64): string =
let time = initTime(int64(value div 1_000_000_000), value mod 1_000_000_000)
time.format("yyyy-MM-dd HH:mm:ss'.'fffzzz")
proc dumpMessage*(conn: SecureConn, direction: FlowDirection, data: openArray[byte]) =
proc dumpMessage*(conn: SecureConn, direction: FlowDirection,
data: openArray[byte]) =
## Store unencrypted message ``data`` to dump file, all the metadata will be
## extracted from ``conn`` instance.
var pb = initProtoBuffer(options = {WithVarintLength})
pb.write(2, getTimestamp())
pb.write(4, uint64(direction))
conn.observedAddr.withValue(oaddr):
pb.write(6, oaddr)
pb.write(6, conn.observedAddr)
pb.write(7, data)
pb.finish()
@@ -92,7 +86,7 @@ proc dumpMessage*(conn: SecureConn, direction: FlowDirection, data: openArray[by
# This is a debugging procedure, so it should not generate any exceptions,
# and we are going to return at every possible OS error.
if not (dirExists(dirName)):
if not(dirExists(dirName)):
try:
createDir(dirName)
except CatchableError:
@@ -106,7 +100,7 @@ proc dumpMessage*(conn: SecureConn, direction: FlowDirection, data: openArray[by
finally:
close(handle)
proc decodeDumpMessage*(data: openArray[byte]): Opt[ProtoMessage] =
proc decodeDumpMessage*(data: openArray[byte]): Option[ProtoMessage] =
## Decode protobuf's message ProtoMessage from array of bytes ``data``.
var
pb = initProtoBuffer(data)
@@ -114,12 +108,13 @@ proc decodeDumpMessage*(data: openArray[byte]): Opt[ProtoMessage] =
ma1, ma2: MultiAddress
pmsg: ProtoMessage
let
r2 = pb.getField(2, pmsg.timestamp)
r4 = pb.getField(4, value)
r7 = pb.getField(7, pmsg.message)
if not r2.get(false) or not r4.get(false) or not r7.get(false):
return Opt.none(ProtoMessage)
let res2 = pb.getField(2, pmsg.timestamp)
if res2.isErr() or not(res2.get()):
return none[ProtoMessage]()
let res4 = pb.getField(4, value)
if res4.isErr() or not(res4.get()):
return none[ProtoMessage]()
# a `case` statement cannot be used here: the compiler rejects it with "selector
# must be of an ordinal type, float or string"
@@ -129,27 +124,30 @@ proc decodeDumpMessage*(data: openArray[byte]): Opt[ProtoMessage] =
elif value == uint64(Incoming):
Incoming
else:
return Opt.none(ProtoMessage)
return none[ProtoMessage]()
let r1 = pb.getField(1, value)
if r1.get(false):
pmsg.seqID = Opt.some(value)
let res7 = pb.getField(7, pmsg.message)
if res7.isErr() or not(res7.get()):
return none[ProtoMessage]()
let r3 = pb.getField(3, value)
if r3.get(false):
pmsg.mtype = Opt.some(value)
value = 0'u64
let res1 = pb.getField(1, value)
if res1.isOk() and res1.get():
pmsg.seqID = some(value)
value = 0'u64
let res3 = pb.getField(3, value)
if res3.isOk() and res3.get():
pmsg.mtype = some(value)
let res5 = pb.getField(5, ma1)
if res5.isOk() and res5.get():
pmsg.local = some(ma1)
let res6 = pb.getField(6, ma2)
if res6.isOk() and res6.get():
pmsg.remote = some(ma2)
let
r5 = pb.getField(5, ma1)
r6 = pb.getField(6, ma2)
if r5.get(false):
pmsg.local = Opt.some(ma1)
if r6.get(false):
pmsg.remote = Opt.some(ma2)
some(pmsg)
Opt.some(pmsg)
iterator messages*(data: seq[byte]): Opt[ProtoMessage] =
iterator messages*(data: seq[byte]): Option[ProtoMessage] =
## Iterate over sequence of bytes and decode all the ``ProtoMessage``
## messages we found.
var value: uint64
@@ -158,11 +156,13 @@ iterator messages*(data: seq[byte]): Opt[ProtoMessage] =
while offset < len(data):
value = 0
size = 0
let res = PB.getUVarint(data.toOpenArray(offset, len(data) - 1), size, value)
let res = PB.getUVarint(data.toOpenArray(offset, len(data) - 1),
size, value)
if res.isOk():
if (value > 0'u64) and (value < uint64(len(data) - offset)):
offset += size
yield decodeDumpMessage(data.toOpenArray(offset, offset + int(value) - 1))
yield decodeDumpMessage(data.toOpenArray(offset,
offset + int(value) - 1))
# value is previously checked to be less than len(data), which is `int`.
offset += int(value)
else:
@@ -182,15 +182,10 @@ proc dumpHex*(pbytes: openArray[byte], groupBy = 1, ascii = true): string =
for k in 0 ..< groupBy:
let ch = pbytes[offset + k]
ascii.add(
if ord(ch) > 31 and ord(ch) < 127:
char(ch)
else:
'.'
)
ascii.add(if ord(ch) > 31 and ord(ch) < 127: char(ch) else: '.')
let item =
case groupBy
case groupBy:
of 1:
toHex(pbytes[offset])
of 2:
@@ -212,7 +207,8 @@ proc dumpHex*(pbytes: openArray[byte], groupBy = 1, ascii = true): string =
res.add("\p")
if (offset mod 16) != 0:
let spacesCount = ((16 - (offset mod 16)) div groupBy) * (groupBy * 2 + 1) + 1
let spacesCount = ((16 - (offset mod 16)) div groupBy) *
(groupBy * 2 + 1) + 1
res = res & repeat(' ', spacesCount)
res = res & ascii
@@ -240,28 +236,31 @@ proc toString*(msg: ProtoMessage, dump = true): string =
var res = getTimedate(msg.timestamp)
let direction =
case msg.direction
of Incoming: " << "
of Outgoing: " >> "
let address = block:
let local = block:
msg.local.withValue(loc):
"[" & $loc & "]"
else:
"[LOCAL]"
let remote = block:
msg.remote.withValue(rem):
"[" & $rem & "]"
else:
"[REMOTE]"
local & direction & remote
let seqid = block:
msg.seqID.withValue(seqid):
"seqID = " & $seqid & " "
of Incoming:
" << "
of Outgoing:
" >> "
let address =
block:
let local =
if msg.local.isSome():
"[" & $(msg.local.get()) & "]"
else:
"[LOCAL]"
let remote =
if msg.remote.isSome():
"[" & $(msg.remote.get()) & "]"
else:
"[REMOTE]"
local & direction & remote
let seqid =
if msg.seqID.isSome():
"seqID = " & $(msg.seqID.get()) & " "
else:
""
let mtype = block:
msg.mtype.withValue(typ):
"type = " & $typ & " "
let mtype =
if msg.mtype.isSome():
"type = " & $(msg.mtype.get()) & " "
else:
""
res.add(" ")

View File

@@ -7,24 +7,28 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import chronos
import stew/results
import peerid, stream/connection, transports/transport
import peerid,
stream/connection,
transports/transport
export results
type Dial* = ref object of RootObj
type
Dial* = ref object of RootObj
method connect*(
self: Dial,
peerId: PeerId,
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true,
dir = Direction.Out,
) {.async, base.} =
self: Dial,
peerId: PeerId,
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true) {.async, base.} =
## connect remote peer without negotiating
## a protocol
##
@@ -32,15 +36,18 @@ method connect*(
doAssert(false, "Not implemented!")
method connect*(
self: Dial, address: MultiAddress, allowUnknownPeerId = false
): Future[PeerId] {.async, base.} =
self: Dial,
address: MultiAddress,
allowUnknownPeerId = false): Future[PeerId] {.async, base.} =
## Connects to a peer and retrieve its PeerId
doAssert(false, "Not implemented!")
method dial*(
self: Dial, peerId: PeerId, protos: seq[string]
): Future[Connection] {.async, base.} =
self: Dial,
peerId: PeerId,
protos: seq[string],
): Future[Connection] {.async, base.} =
## create a protocol stream over an
## existing connection
##
@@ -48,22 +55,24 @@ method dial*(
doAssert(false, "Not implemented!")
method dial*(
self: Dial,
peerId: PeerId,
addrs: seq[MultiAddress],
protos: seq[string],
forceDial = false,
): Future[Connection] {.async, base.} =
self: Dial,
peerId: PeerId,
addrs: seq[MultiAddress],
protos: seq[string],
forceDial = false): Future[Connection] {.async, base.} =
## create a protocol stream and establish
## a connection if one doesn't exist already
##
doAssert(false, "Not implemented!")
method addTransport*(self: Dial, transport: Transport) {.base.} =
method addTransport*(
self: Dial,
transport: Transport) {.base.} =
doAssert(false, "Not implemented!")
method tryDial*(
self: Dial, peerId: PeerId, addrs: seq[MultiAddress]
): Future[Opt[MultiAddress]] {.async, base.} =
self: Dial,
peerId: PeerId,
addrs: seq[MultiAddress]): Future[Opt[MultiAddress]] {.async, base.} =
doAssert(false, "Not implemented!")
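
For orientation, driving the abstract interface above might look like this (the protocol id is hypothetical):

proc fetchSomething(d: Dial, peer: PeerId, addrs: seq[MultiAddress]) {.async.} =
  await d.connect(peer, addrs)
  let stream = await d.dial(peer, @["/my/proto/1.0.0"])  # hypothetical protocol id
  defer: await stream.closeWithEOF()
  # read from / write to `stream` here
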

View File

@@ -7,25 +7,24 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
import std/tables
import std/[sugar, tables, sequtils]
import stew/results
import pkg/[chronos, chronicles, metrics]
import pkg/[chronos,
chronicles,
metrics]
import
dial,
peerid,
peerinfo,
peerstore,
multicodec,
muxers/muxer,
multistream,
connmanager,
stream/connection,
transports/transport,
nameresolving/nameresolver,
upgrademngrs/upgrade,
errors
import dial,
peerid,
peerinfo,
multicodec,
multistream,
connmanager,
stream/connection,
transports/transport,
nameresolving/nameresolver,
upgrademngrs/upgrade,
errors
export dial, errors, results
@@ -35,79 +34,74 @@ logScope:
declareCounter(libp2p_total_dial_attempts, "total attempted dials")
declareCounter(libp2p_successful_dials, "dialed successful peers")
declareCounter(libp2p_failed_dials, "failed dials")
declareCounter(libp2p_failed_upgrades_outgoing, "outgoing connections failed upgrades")
type
DialFailedError* = object of LPError
Dialer* = ref object of Dial
localPeerId*: PeerId
ms: MultistreamSelect
connManager: ConnManager
dialLock: Table[PeerId, AsyncLock]
transports: seq[Transport]
peerStore: PeerStore
nameResolver: NameResolver
proc dialAndUpgrade(
self: Dialer,
peerId: Opt[PeerId],
hostname: string,
address: MultiAddress,
dir = Direction.Out,
): Future[Muxer] {.async.} =
self: Dialer,
peerId: Opt[PeerId],
hostname: string,
address: MultiAddress):
Future[Connection] {.async.} =
for transport in self.transports: # for each transport
if transport.handles(address): # check if it can dial it
trace "Dialing address", address, peerId = peerId.get(default(PeerId)), hostname
if transport.handles(address): # check if it can dial it
trace "Dialing address", address, peerId, hostname
let dialed =
try:
libp2p_total_dial_attempts.inc()
await transport.dial(hostname, address, peerId)
except CancelledError as exc:
debug "Dialing canceled", err = exc.msg, peerId = peerId.get(default(PeerId))
debug "Dialing canceled", msg = exc.msg, peerId
raise exc
except CatchableError as exc:
debug "Dialing failed", err = exc.msg, peerId = peerId.get(default(PeerId))
debug "Dialing failed", msg = exc.msg, peerId
libp2p_failed_dials.inc()
return nil # Try the next address
# also keep track of the connection's bottom unsafe transport direction
# required by gossipsub scoring
dialed.transportDir = Direction.Out
libp2p_successful_dials.inc()
let mux =
let conn =
try:
# This is for the very specific case of a simultaneous dial during DCUtR. In this case, both sides will have
# an Outbound direction at the transport level. Therefore we update the DCUtR initiator transport direction to Inbound.
# The if below is more general and might handle other use cases in the future.
if dialed.dir != dir:
dialed.dir = dir
await transport.upgrade(dialed, peerId)
except CancelledError as exc:
await dialed.close()
raise exc
await transport.upgradeOutgoing(dialed, peerId)
except CatchableError as exc:
# If we failed to establish the connection through one transport,
# we won't succeed through another - no use in trying again
await dialed.close()
debug "Connection upgrade failed",
err = exc.msg, peerId = peerId.get(default(PeerId))
if dialed.dir == Direction.Out:
debug "Upgrade failed", msg = exc.msg, peerId
if exc isnot CancelledError:
libp2p_failed_upgrades_outgoing.inc()
else:
libp2p_failed_upgrades_incoming.inc()
# Try other address
return nil
doAssert not isNil(mux), "connection died after upgrade " & $dialed.dir
debug "Dial successful", peerId = mux.connection.peerId
return mux
doAssert not isNil(conn), "connection died after upgradeOutgoing"
debug "Dial successful", conn, peerId = conn.peerId
return conn
return nil
proc expandDnsAddr(
self: Dialer, peerId: Opt[PeerId], address: MultiAddress
): Future[seq[(MultiAddress, Opt[PeerId])]] {.async.} =
if not DNSADDR.matchPartial(address):
return @[(address, peerId)]
self: Dialer,
peerId: Opt[PeerId],
address: MultiAddress): Future[seq[(MultiAddress, Opt[PeerId])]] {.async.} =
if not DNSADDR.matchPartial(address): return @[(address, peerId)]
if isNil(self.nameResolver):
info "Can't resolve DNSADDR without NameResolver", ma = address
info "Can't resolve DNSADDR without NameResolver", ma=address
return @[]
let
@@ -124,14 +118,17 @@ proc expandDnsAddr(
let
peerIdBytes = lastPart.protoArgument().tryGet()
addrPeerId = PeerId.init(peerIdBytes).tryGet()
result.add((resolvedAddress[0 ..^ 2].tryGet(), Opt.some(addrPeerId)))
result.add((resolvedAddress[0..^2].tryGet(), Opt.some(addrPeerId)))
else:
result.add((resolvedAddress, peerId))
proc dialAndUpgrade(
self: Dialer, peerId: Opt[PeerId], addrs: seq[MultiAddress], dir = Direction.Out
): Future[Muxer] {.async.} =
debug "Dialing peer", peerId = peerId.get(default(PeerId)), addrs
self: Dialer,
peerId: Opt[PeerId],
addrs: seq[MultiAddress]):
Future[Connection] {.async.} =
debug "Dialing peer", peerId
for rawAddress in addrs:
# resolve potential dnsaddr
@@ -142,32 +139,37 @@ proc dialAndUpgrade(
let
hostname = expandedAddress.getHostname()
resolvedAddresses =
if isNil(self.nameResolver):
@[expandedAddress]
else:
await self.nameResolver.resolveMAddress(expandedAddress)
if isNil(self.nameResolver): @[expandedAddress]
else: await self.nameResolver.resolveMAddress(expandedAddress)
for resolvedAddress in resolvedAddresses:
result = await self.dialAndUpgrade(addrPeerId, hostname, resolvedAddress, dir)
result = await self.dialAndUpgrade(addrPeerId, hostname, resolvedAddress)
if not isNil(result):
return result
proc tryReusingConnection(self: Dialer, peerId: PeerId): Opt[Muxer] =
let muxer = self.connManager.selectMuxer(peerId)
if muxer == nil:
return Opt.none(Muxer)
proc tryReusingConnection(self: Dialer, peerId: PeerId): Future[Opt[Connection]] {.async.} =
var conn = self.connManager.selectConn(peerId)
if conn == nil:
return Opt.none(Connection)
trace "Reusing existing connection", muxer, direction = $muxer.connection.dir
return Opt.some(muxer)
if conn.atEof or conn.closed:
# This connection should already have been removed from the connection
# manager - it's essentially a bug that we end up here - we'll fail
# for now, hoping that these will clean themselves up later...
warn "dead connection in connection manager", conn
await conn.close()
raise newException(DialFailedError, "Zombie connection encountered")
trace "Reusing existing connection", conn, direction = $conn.dir
return Opt.some(conn)
proc internalConnect(
self: Dialer,
peerId: Opt[PeerId],
addrs: seq[MultiAddress],
forceDial: bool,
reuseConnection = true,
dir = Direction.Out,
): Future[Muxer] {.async.} =
self: Dialer,
peerId: Opt[PeerId],
addrs: seq[MultiAddress],
forceDial: bool,
reuseConnection = true):
Future[Connection] {.async.} =
if Opt.some(self.localPeerId) == peerId:
raise newException(CatchableError, "can't dial self!")
@@ -176,47 +178,43 @@ proc internalConnect(
try:
await lock.acquire()
if reuseConnection:
peerId.withValue(peerId):
self.tryReusingConnection(peerId).withValue(mux):
return mux
if peerId.isSome and reuseConnection:
let connOpt = await self.tryReusingConnection(peerId.get())
if connOpt.isSome:
return connOpt.get()
let slot = self.connManager.getOutgoingSlot(forceDial)
let muxed =
let conn =
try:
await self.dialAndUpgrade(peerId, addrs, dir)
await self.dialAndUpgrade(peerId, addrs)
except CatchableError as exc:
slot.release()
raise exc
slot.trackMuxer(muxed)
if isNil(muxed): # None of the addresses connected
slot.trackConnection(conn)
if isNil(conn): # None of the addresses connected
raise newException(DialFailedError, "Unable to establish outgoing link")
try:
self.connManager.storeMuxer(muxed)
await self.peerStore.identify(muxed)
await self.connManager.triggerPeerEvents(
muxed.connection.peerId,
PeerEvent(kind: PeerEventKind.Identified, initiator: true),
)
except CatchableError as exc:
trace "Failed to finish outgoung upgrade", err = exc.msg
await muxed.close()
raise exc
# A disconnect could have happened right after
# we've added the connection so we check again
# to prevent races due to that.
if conn.closed() or conn.atEof():
# This can happen when the other ends drops us
# before we get a chance to return the connection
# back to the dialer.
trace "Connection dead on arrival", conn
raise newLPStreamClosedError()
return muxed
return conn
finally:
if lock.locked():
lock.release()
method connect*(
self: Dialer,
peerId: PeerId,
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true,
dir = Direction.Out,
) {.async.} =
self: Dialer,
peerId: PeerId,
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true) {.async.} =
## connect remote peer without negotiating
## a protocol
##
@@ -224,32 +222,34 @@ method connect*(
if self.connManager.connCount(peerId) > 0 and reuseConnection:
return
discard
await self.internalConnect(Opt.some(peerId), addrs, forceDial, reuseConnection, dir)
discard await self.internalConnect(Opt.some(peerId), addrs, forceDial, reuseConnection)
method connect*(
self: Dialer, address: MultiAddress, allowUnknownPeerId = false
): Future[PeerId] {.async.} =
self: Dialer,
address: MultiAddress,
allowUnknownPeerId = false): Future[PeerId] {.async.} =
## Connects to a peer and retrieve its PeerId
parseFullAddress(address).toOpt().withValue(fullAddress):
return (
await self.internalConnect(Opt.some(fullAddress[0]), @[fullAddress[1]], false)
).connection.peerId
if allowUnknownPeerId == false:
raise newException(
DialFailedError, "Address without PeerID and unknown peer id disabled!"
)
return
(await self.internalConnect(Opt.none(PeerId), @[address], false)).connection.peerId
let fullAddress = parseFullAddress(address)
if fullAddress.isOk:
return (await self.internalConnect(
Opt.some(fullAddress.get()[0]),
@[fullAddress.get()[1]],
false)).peerId
else:
if allowUnknownPeerId == false:
raise newException(DialFailedError, "Address without PeerID and unknown peer id disabled!")
return (await self.internalConnect(
Opt.none(PeerId),
@[address],
false)).peerId
proc negotiateStream(
self: Dialer, conn: Connection, protos: seq[string]
): Future[Connection] {.async.} =
self: Dialer,
conn: Connection,
protos: seq[string]): Future[Connection] {.async.} =
trace "Negotiating stream", conn, protos
let selected = await MultistreamSelect.select(conn, protos)
let selected = await self.ms.select(conn, protos)
if not protos.contains(selected):
await conn.closeWithEOF()
raise newException(DialFailedError, "Unable to select sub-protocol " & $protos)
@@ -257,8 +257,9 @@ proc negotiateStream(
return conn
method tryDial*(
self: Dialer, peerId: PeerId, addrs: seq[MultiAddress]
): Future[Opt[MultiAddress]] {.async.} =
self: Dialer,
peerId: PeerId,
addrs: seq[MultiAddress]): Future[Opt[MultiAddress]] {.async.} =
## Create a protocol stream in order to check
## if a connection is possible.
## Doesn't use the Connection Manager to save it.
@@ -266,19 +267,20 @@ method tryDial*(
trace "Check if it can dial", peerId, addrs
try:
let mux = await self.dialAndUpgrade(Opt.some(peerId), addrs)
if mux.isNil():
let conn = await self.dialAndUpgrade(Opt.some(peerId), addrs)
if conn.isNil():
raise newException(DialFailedError, "No valid multiaddress")
await mux.close()
return mux.connection.observedAddr
await conn.close()
return conn.observedAddr
except CancelledError as exc:
raise exc
except CatchableError as exc:
raise newException(DialFailedError, exc.msg)
method dial*(
self: Dialer, peerId: PeerId, protos: seq[string]
): Future[Connection] {.async.} =
self: Dialer,
peerId: PeerId,
protos: seq[string]): Future[Connection] {.async.} =
## create a protocol stream over an
## existing connection
##
@@ -291,25 +293,24 @@ method dial*(
return await self.negotiateStream(stream, protos)
method dial*(
self: Dialer,
peerId: PeerId,
addrs: seq[MultiAddress],
protos: seq[string],
forceDial = false,
): Future[Connection] {.async.} =
self: Dialer,
peerId: PeerId,
addrs: seq[MultiAddress],
protos: seq[string],
forceDial = false): Future[Connection] {.async.} =
## create a protocol stream and establish
## a connection if one doesn't exist already
##
var
conn: Muxer
conn: Connection
stream: Connection
proc cleanup() {.async.} =
if not (isNil(stream)):
if not(isNil(stream)):
await stream.closeWithEOF()
if not (isNil(conn)):
if not(isNil(conn)):
await conn.close()
try:
@@ -319,7 +320,8 @@ method dial*(
stream = await self.connManager.getStream(conn)
if isNil(stream):
raise newException(DialFailedError, "Couldn't get muxed stream")
raise newException(DialFailedError,
"Couldn't get muxed stream")
return await self.negotiateStream(stream, protos)
except CancelledError as exc:
@@ -327,7 +329,7 @@ method dial*(
await cleanup()
raise exc
except CatchableError as exc:
debug "Error dialing", conn, err = exc.msg
debug "Error dialing", conn, msg = exc.msg
await cleanup()
raise exc
@@ -335,17 +337,15 @@ method addTransport*(self: Dialer, t: Transport) =
self.transports &= t
proc new*(
T: type Dialer,
localPeerId: PeerId,
connManager: ConnManager,
peerStore: PeerStore,
transports: seq[Transport],
nameResolver: NameResolver = nil,
): Dialer =
T(
localPeerId: localPeerId,
T: type Dialer,
localPeerId: PeerId,
connManager: ConnManager,
transports: seq[Transport],
ms: MultistreamSelect,
nameResolver: NameResolver = nil): Dialer =
T(localPeerId: localPeerId,
connManager: connManager,
transports: transports,
peerStore: peerStore,
nameResolver: nameResolver,
)
ms: ms,
nameResolver: nameResolver)

View File

@@ -7,7 +7,10 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import std/sequtils
import chronos, chronicles, stew/results
@@ -15,7 +18,7 @@ import ../errors
type
BaseAttr = ref object of RootObj
comparator: proc(f, c: BaseAttr): bool {.gcsafe, raises: [].}
comparator: proc(f, c: BaseAttr): bool {.gcsafe, raises: [Defect].}
Attribute[T] = ref object of BaseAttr
value: T
@@ -33,13 +36,12 @@ proc ofType*[T](f: BaseAttr, _: type[T]): bool =
proc to*[T](f: BaseAttr, _: type[T]): T =
Attribute[T](f).value
proc add*[T](pa: var PeerAttributes, value: T) =
pa.attributes.add(
Attribute[T](
proc add*[T](pa: var PeerAttributes,
value: T) =
pa.attributes.add(Attribute[T](
value: value,
comparator: proc(f: BaseAttr, c: BaseAttr): bool =
f.ofType(T) and c.ofType(T) and f.to(T) == c.to(T)
,
)
)
@@ -58,9 +60,8 @@ proc `{}`*[T](pa: PeerAttributes, t: typedesc[T]): Opt[T] =
return Opt.some(f.to(T))
Opt.none(T)
proc `[]`*[T](pa: PeerAttributes, t: typedesc[T]): T {.raises: [KeyError].} =
pa{T}.valueOr:
raise newException(KeyError, "Attribute not found")
proc `[]`*[T](pa: PeerAttributes, t: typedesc[T]): T {.raises: [Defect, KeyError].} =
pa{T}.valueOr: raise newException(KeyError, "Attribute not found")
proc match*(pa, candidate: PeerAttributes): bool =
for f in pa.attributes:
@@ -72,7 +73,7 @@ proc match*(pa, candidate: PeerAttributes): bool =
return true
type
PeerFoundCallback* = proc(pa: PeerAttributes) {.raises: [], gcsafe.}
PeerFoundCallback* = proc(pa: PeerAttributes) {.raises: [Defect], gcsafe.}
DiscoveryInterface* = ref object of RootObj
onPeerFound*: PeerFoundCallback
@@ -103,7 +104,7 @@ type
proc add*(dm: DiscoveryManager, di: DiscoveryInterface) =
dm.interfaces &= di
di.onPeerFound = proc(pa: PeerAttributes) =
di.onPeerFound = proc (pa: PeerAttributes) =
for query in dm.queries:
if query.attr.match(pa):
try:
@@ -124,15 +125,20 @@ proc request*[T](dm: DiscoveryManager, value: T): DiscoveryQuery =
pa.add(value)
return dm.request(pa)
proc advertise*[T](dm: DiscoveryManager, value: T) =
proc advertise*(dm: DiscoveryManager, pa: PeerAttributes) =
for i in dm.interfaces:
i.toAdvertise.add(value)
i.toAdvertise = pa
if i.advertiseLoop.isNil:
i.advertisementUpdated = newAsyncEvent()
i.advertiseLoop = i.advertise()
else:
i.advertisementUpdated.fire()
proc advertise*[T](dm: DiscoveryManager, value: T) =
var pa: PeerAttributes
pa.add(value)
dm.advertise(pa)
template forEach*(query: DiscoveryQuery, code: untyped) =
## Will execute `code` for each discovered peer. The
## peer attributes are available through the variable
@@ -141,10 +147,8 @@ template forEach*(query: DiscoveryQuery, code: untyped) =
proc forEachInternal(q: DiscoveryQuery) {.async.} =
while true:
let peer {.inject.} =
try:
await q.getPeer()
except DiscoveryFinished:
return
try: await q.getPeer()
except DiscoveryFinished: return
code
asyncSpawn forEachInternal(query)
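
A usage sketch for the manager (the namespace is hypothetical; `RendezVousInterface` comes from the next file):

var dm = DiscoveryManager()
dm.add(RendezVousInterface.new(rdv))       # rdv: an initialized RendezVous
dm.advertise(RdvNamespace("my-app"))       # hypothetical namespace
let query = dm.request(RdvNamespace("my-app"))
query.forEach:
  let pid = peer{PeerId}
  if pid.isSome():
    echo "discovered: ", pid.get()
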
@@ -152,15 +156,13 @@ template forEach*(query: DiscoveryQuery, code: untyped) =
proc stop*(query: DiscoveryQuery) =
query.finished = true
for r in query.futs:
if not r.finished():
r.cancel()
if not r.finished(): r.cancel()
proc stop*(dm: DiscoveryManager) =
for q in dm.queries:
q.stop()
for i in dm.interfaces:
if isNil(i.advertiseLoop):
continue
if isNil(i.advertiseLoop): continue
i.advertiseLoop.cancel()
proc getPeer*(query: DiscoveryQuery): Future[PeerAttributes] {.async.} =

View File

@@ -7,17 +7,22 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import sequtils
import chronos
import ./discoverymngr, ../protocols/rendezvous, ../peerid
import ./discoverymngr,
../protocols/rendezvous,
../peerid
type
RendezVousInterface* = ref object of DiscoveryInterface
rdv*: RendezVous
timeToRequest: Duration
timeToAdvertise: Duration
ttl: Duration
RdvNamespace* = distinct string
@@ -61,18 +66,12 @@ method advertise*(self: RendezVousInterface) {.async.} =
self.advertisementUpdated.clear()
for toAdv in toAdvertise:
try:
await self.rdv.advertise(toAdv, self.ttl)
except CatchableError as error:
debug "RendezVous advertise error: ", msg = error.msg
await self.rdv.advertise(toAdv, self.timeToAdvertise)
await sleepAsync(self.timeToAdvertise) or self.advertisementUpdated.wait()
proc new*(
T: typedesc[RendezVousInterface],
rdv: RendezVous,
ttr: Duration = 1.minutes,
tta: Duration = 1.minutes,
ttl: Duration = MinimumDuration,
): RendezVousInterface =
T(rdv: rdv, timeToRequest: ttr, timeToAdvertise: tta, ttl: ttl)
proc new*(T: typedesc[RendezVousInterface],
rdv: RendezVous,
ttr: Duration = 1.minutes,
tta: Duration = MinimumDuration): RendezVousInterface =
T(rdv: rdv, timeToRequest: ttr, timeToAdvertise: tta)
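A wiring sketch for the constructor, assuming a RendezVous instance from protocols/rendezvous that is already mounted on a switch; named arguments keep the call valid for either signature in this diff:

let rdvItf = RendezVousInterface.new(rdv, ttr = 30.seconds, tta = 1.minutes)
dm.add(rdvItf)  # DiscoveryManager.add hooks up onPeerFound (see discoverymngr)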
View File
@@ -19,30 +19,58 @@ func toException*(e: string): ref LPError =
# sadly nim needs more love for hygienic templates
# so here goes the macro, it's based on the proc/template version
# and uses quote do so it's quite readable
# TODO https://github.com/nim-lang/Nim/issues/22936
macro checkFutures*[F](futs: seq[F], exclude: untyped = []): untyped =
macro checkFutures*[T](futs: seq[Future[T]], exclude: untyped = []): untyped =
let nexclude = exclude.len
case nexclude
of 0:
quote:
quote do:
for res in `futs`:
if res.failed:
let exc = res.readError()
# We still don't abort but warn
debug "A future has failed, enable trace logging for details",
error = exc.name
trace "Exception message", msg = exc.msg, stack = getStackTrace()
debug "A future has failed, enable trace logging for details", error = exc.name
trace "Exception message", msg= exc.msg, stack = getStackTrace()
else:
quote:
quote do:
for res in `futs`:
block check:
if res.failed:
let exc = res.readError()
for i in 0 ..< `nexclude`:
for i in 0..<`nexclude`:
if exc of `exclude`[i]:
trace "A future has failed", error = exc.name, msg = exc.msg
trace "A future has failed", error=exc.name, msg=exc.msg
break check
# We still don't abort but warn
debug "A future has failed, enable trace logging for details",
error = exc.name
trace "Exception details", msg = exc.msg
debug "A future has failed, enable trace logging for details", error=exc.name
trace "Exception details", msg=exc.msg
proc allFuturesThrowing*[T](args: varargs[Future[T]]): Future[void] =
var futs: seq[Future[T]]
for fut in args:
futs &= fut
proc call() {.async.} =
var first: ref CatchableError = nil
futs = await allFinished(futs)
for fut in futs:
if fut.failed:
let err = fut.readError()
if err of Defect:
raise err
else:
if err of CancelledError:
raise err
if isNil(first):
first = err
if not isNil(first):
raise first
return call()
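In contrast to checkFutures, allFuturesThrowing re-raises: Defects and CancelledError propagate as-is, otherwise the first CatchableError is raised after every future has settled. A call-shape sketch (mapIt is from std/sequtils; close returning a Future is an assumption):

import sequtils

proc closeAll(conns: seq[Connection]) {.async.} =
  await allFuturesThrowing(conns.mapIt(it.close()))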
template tryAndWarn*(message: static[string]; body: untyped): untyped =
try:
body
except CancelledError as exc:
raise exc
except CatchableError as exc:
debug "An exception has ocurred, enable trace logging for details", name = exc.name, msg = message
trace "Exception details", exc = exc.msg
View File
@@ -9,35 +9,23 @@
## This module implements MultiAddress.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
{.push public.}
import pkg/chronos, chronicles
import std/[nativesockets, net, hashes]
import tables, strutils, sets
import
multicodec,
multihash,
multibase,
transcoder,
vbuffer,
peerid,
protobuf/minprotobuf,
errors,
utility
import pkg/chronos
import std/[nativesockets, hashes]
import tables, strutils, sets, stew/shims/net
import multicodec, multihash, multibase, transcoder, vbuffer, peerid,
protobuf/minprotobuf, errors, utility
import stew/[base58, base32, endians2, results]
export results, minprotobuf, vbuffer, utility
logScope:
topics = "libp2p multiaddress"
type
MAKind* = enum
None
Fixed
Length
Path
Marker
None, Fixed, Length, Path, Marker
MAProtocol* = object
mcodec*: MultiCodec
@@ -46,12 +34,10 @@ type
coder*: Transcoder
MultiAddress* = object
data: VBuffer
data*: VBuffer
MaPatternOp* = enum
Eq
Or
And
Eq, Or, And
MaPattern* = object
operator*: MaPatternOp
@@ -77,10 +63,6 @@ const
IPPROTO_TCP = Protocol.IPPROTO_TCP
IPPROTO_UDP = Protocol.IPPROTO_UDP
proc data*(ma: MultiAddress): VBuffer =
## Returns the data buffer of the MultiAddress.
return ma.data
proc hash*(a: MultiAddress): Hash =
var h: Hash = 0
h = h !& hash(a.data.buffer)
@@ -94,7 +76,7 @@ proc ip4StB(s: string, vb: var VBuffer): bool =
if a.family == IpAddressFamily.IPv4:
vb.writeArray(a.address_v4)
result = true
except CatchableError:
except:
discard
proc ip4BtS(vb: var VBuffer, s: var string): bool =
@@ -117,7 +99,7 @@ proc ip6StB(s: string, vb: var VBuffer): bool =
if a.family == IpAddressFamily.IPv6:
vb.writeArray(a.address_v6)
result = true
except CatchableError:
except:
discard
proc ip6BtS(vb: var VBuffer, s: var string): bool =
@@ -133,40 +115,23 @@ proc ip6VB(vb: var VBuffer): bool =
if vb.readArray(a.address_v6) == 16:
result = true
template pathStringToBuffer(s: string, vb: var VBuffer): bool =
if len(s) > 0:
vb.writeSeq(s)
true
else:
false
template pathBufferToString(vb: var VBuffer, s: var string): bool =
s = ""
if (vb.readSeq(s) > 0) and (len(s) > 0): true else: false
template pathBufferToStringNoSlash(vb: var VBuffer, s: var string): bool =
s = ""
if (vb.readSeq(s) > 0) and (len(s) > 0) and (s.find('/') == -1): true else: false
template pathValidateBuffer(vb: var VBuffer): bool =
var s = ""
pathBufferToString(vb, s)
template pathValidateBufferNoSlash(vb: var VBuffer): bool =
var s = ""
pathBufferToStringNoSlash(vb, s)
proc ip6zoneStB(s: string, vb: var VBuffer): bool =
## IPv6 stringToBuffer() implementation.
pathStringToBuffer(s, vb)
if len(s) > 0:
vb.writeSeq(s)
result = true
proc ip6zoneBtS(vb: var VBuffer, s: var string): bool =
## IPv6 bufferToString() implementation.
pathBufferToStringNoSlash(vb, s)
if vb.readSeq(s) > 0:
result = true
proc ip6zoneVB(vb: var VBuffer): bool =
## IPv6 validateBuffer() implementation.
pathValidateBufferNoSlash(vb)
var s = ""
if vb.readSeq(s) > 0:
if s.find('/') == -1:
result = true
proc portStB(s: string, vb: var VBuffer): bool =
## Port number stringToBuffer() implementation.
@@ -178,14 +143,14 @@ proc portStB(s: string, vb: var VBuffer): bool =
port[1] = cast[byte](nport and 0xFF)
vb.writeArray(port)
result = true
except CatchableError:
except:
discard
proc portBtS(vb: var VBuffer, s: var string): bool =
## Port number bufferToString() implementation.
var port: array[2, byte]
if vb.readArray(port) == 2:
let nport = (safeConvert[uint16](port[0]) shl 8) or safeConvert[uint16](port[1])
var nport = (cast[uint16](port[0]) shl 8) or cast[uint16](port[1])
s = $nport
result = true
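A worked check of the big-endian decode above, written with plain widening conversions (the safeConvert form on the other side of this diff behaves the same): port 8080 travels as the bytes 0x1F 0x90:

let wire = [byte 0x1F, 0x90]
doAssert ((uint16(wire[0]) shl 8) or uint16(wire[1])) == 8080'u16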
@@ -203,7 +168,7 @@ proc p2pStB(s: string, vb: var VBuffer): bool =
if MultiHash.decode(data, mh).isOk:
vb.writeSeq(data)
result = true
except CatchableError:
except:
discard
proc p2pBtS(vb: var VBuffer, s: var string): bool =
@@ -238,14 +203,14 @@ proc onionStB(s: string, vb: var VBuffer): bool =
address[11] = cast[byte](nport and 0xFF)
vb.writeArray(address)
result = true
except CatchableError:
except:
discard
proc onionBtS(vb: var VBuffer, s: var string): bool =
## ONION address bufferToString() implementation.
var buf: array[12, byte]
if vb.readArray(buf) == 12:
let nport = (safeConvert[uint16](buf[10]) shl 8) or safeConvert[uint16](buf[11])
var nport = (cast[uint16](buf[10]) shl 8) or cast[uint16](buf[11])
s = Base32Lower.encode(buf.toOpenArray(0, 9))
s.add(":")
s.add($nport)
@@ -272,14 +237,14 @@ proc onion3StB(s: string, vb: var VBuffer): bool =
address[36] = cast[byte](nport and 0xFF)
vb.writeArray(address)
result = true
except CatchableError:
except:
discard
proc onion3BtS(vb: var VBuffer, s: var string): bool =
## ONION address bufferToString() implementation.
var buf: array[37, byte]
if vb.readArray(buf) == 37:
var nport = (safeConvert[uint16](buf[35]) shl 8) or safeConvert[uint16](buf[36])
var nport = (cast[uint16](buf[35]) shl 8) or cast[uint16](buf[36])
s = Base32Lower.encode(buf.toOpenArray(0, 34))
s.add(":")
s.add($nport)
@@ -293,27 +258,40 @@ proc onion3VB(vb: var VBuffer): bool =
proc unixStB(s: string, vb: var VBuffer): bool =
## Unix socket name stringToBuffer() implementation.
pathStringToBuffer(s, vb)
if len(s) > 0:
vb.writeSeq(s)
result = true
proc unixBtS(vb: var VBuffer, s: var string): bool =
## Unix socket name bufferToString() implementation.
pathBufferToString(vb, s)
s = ""
if vb.readSeq(s) > 0:
result = true
proc unixVB(vb: var VBuffer): bool =
## Unix socket name validateBuffer() implementation.
pathValidateBuffer(vb)
var s = ""
if vb.readSeq(s) > 0:
result = true
proc dnsStB(s: string, vb: var VBuffer): bool =
## DNS name stringToBuffer() implementation.
pathStringToBuffer(s, vb)
if len(s) > 0:
vb.writeSeq(s)
result = true
proc dnsBtS(vb: var VBuffer, s: var string): bool =
## DNS name bufferToString() implementation.
pathBufferToStringNoSlash(vb, s)
s = ""
if vb.readSeq(s) > 0:
result = true
proc dnsVB(vb: var VBuffer): bool =
## DNS name validateBuffer() implementation.
pathValidateBufferNoSlash(vb)
var s = ""
if vb.readSeq(s) > 0:
if s.find('/') == -1:
result = true
proc mapEq*(codec: string): MaPattern =
## ``Equal`` operator for pattern
@@ -331,65 +309,149 @@ proc mapAnd*(args: varargs[MaPattern]): MaPattern =
result.args = @args
const
TranscoderIP4* =
Transcoder(stringToBuffer: ip4StB, bufferToString: ip4BtS, validateBuffer: ip4VB)
TranscoderIP6* =
Transcoder(stringToBuffer: ip6StB, bufferToString: ip6BtS, validateBuffer: ip6VB)
TranscoderIP6Zone* = Transcoder(
stringToBuffer: ip6zoneStB, bufferToString: ip6zoneBtS, validateBuffer: ip6zoneVB
TranscoderIP4* = Transcoder(
stringToBuffer: ip4StB,
bufferToString: ip4BtS,
validateBuffer: ip4VB
)
TranscoderIP6* = Transcoder(
stringToBuffer: ip6StB,
bufferToString: ip6BtS,
validateBuffer: ip6VB
)
TranscoderIP6Zone* = Transcoder(
stringToBuffer: ip6zoneStB,
bufferToString: ip6zoneBtS,
validateBuffer: ip6zoneVB
)
TranscoderUnix* = Transcoder(
stringToBuffer: unixStB,
bufferToString: unixBtS,
validateBuffer: unixVB
)
TranscoderP2P* = Transcoder(
stringToBuffer: p2pStB,
bufferToString: p2pBtS,
validateBuffer: p2pVB
)
TranscoderPort* = Transcoder(
stringToBuffer: portStB,
bufferToString: portBtS,
validateBuffer: portVB
)
TranscoderUnix* =
Transcoder(stringToBuffer: unixStB, bufferToString: unixBtS, validateBuffer: unixVB)
TranscoderP2P* =
Transcoder(stringToBuffer: p2pStB, bufferToString: p2pBtS, validateBuffer: p2pVB)
TranscoderPort* =
Transcoder(stringToBuffer: portStB, bufferToString: portBtS, validateBuffer: portVB)
TranscoderOnion* = Transcoder(
stringToBuffer: onionStB, bufferToString: onionBtS, validateBuffer: onionVB
stringToBuffer: onionStB,
bufferToString: onionBtS,
validateBuffer: onionVB
)
TranscoderOnion3* = Transcoder(
stringToBuffer: onion3StB, bufferToString: onion3BtS, validateBuffer: onion3VB
stringToBuffer: onion3StB,
bufferToString: onion3BtS,
validateBuffer: onion3VB
)
TranscoderDNS* = Transcoder(
stringToBuffer: dnsStB,
bufferToString: dnsBtS,
validateBuffer: dnsVB
)
TranscoderDNS* =
Transcoder(stringToBuffer: dnsStB, bufferToString: dnsBtS, validateBuffer: dnsVB)
ProtocolsList = [
MAProtocol(mcodec: multiCodec("ip4"), kind: Fixed, size: 4, coder: TranscoderIP4),
MAProtocol(mcodec: multiCodec("tcp"), kind: Fixed, size: 2, coder: TranscoderPort),
MAProtocol(mcodec: multiCodec("udp"), kind: Fixed, size: 2, coder: TranscoderPort),
MAProtocol(mcodec: multiCodec("ip6"), kind: Fixed, size: 16, coder: TranscoderIP6),
MAProtocol(mcodec: multiCodec("dccp"), kind: Fixed, size: 2, coder: TranscoderPort),
MAProtocol(mcodec: multiCodec("sctp"), kind: Fixed, size: 2, coder: TranscoderPort),
MAProtocol(mcodec: multiCodec("udt"), kind: Marker, size: 0),
MAProtocol(mcodec: multiCodec("utp"), kind: Marker, size: 0),
MAProtocol(mcodec: multiCodec("http"), kind: Marker, size: 0),
MAProtocol(mcodec: multiCodec("https"), kind: Marker, size: 0),
MAProtocol(mcodec: multiCodec("quic"), kind: Marker, size: 0),
MAProtocol(mcodec: multiCodec("quic-v1"), kind: Marker, size: 0),
MAProtocol(
mcodec: multiCodec("ip6zone"), kind: Length, size: 0, coder: TranscoderIP6Zone
mcodec: multiCodec("ip4"), kind: Fixed, size: 4,
coder: TranscoderIP4
),
MAProtocol(
mcodec: multiCodec("onion"), kind: Fixed, size: 10, coder: TranscoderOnion
mcodec: multiCodec("tcp"), kind: Fixed, size: 2,
coder: TranscoderPort
),
MAProtocol(
mcodec: multiCodec("onion3"), kind: Fixed, size: 37, coder: TranscoderOnion3
mcodec: multiCodec("udp"), kind: Fixed, size: 2,
coder: TranscoderPort
),
MAProtocol(mcodec: multiCodec("ws"), kind: Marker, size: 0),
MAProtocol(mcodec: multiCodec("wss"), kind: Marker, size: 0),
MAProtocol(mcodec: multiCodec("tls"), kind: Marker, size: 0),
MAProtocol(mcodec: multiCodec("ipfs"), kind: Length, size: 0, coder: TranscoderP2P),
MAProtocol(mcodec: multiCodec("p2p"), kind: Length, size: 0, coder: TranscoderP2P),
MAProtocol(mcodec: multiCodec("unix"), kind: Path, size: 0, coder: TranscoderUnix),
MAProtocol(mcodec: multiCodec("dns"), kind: Length, size: 0, coder: TranscoderDNS),
MAProtocol(mcodec: multiCodec("dns4"), kind: Length, size: 0, coder: TranscoderDNS),
MAProtocol(mcodec: multiCodec("dns6"), kind: Length, size: 0, coder: TranscoderDNS),
MAProtocol(
mcodec: multiCodec("dnsaddr"), kind: Length, size: 0, coder: TranscoderDNS
mcodec: multiCodec("ip6"), kind: Fixed, size: 16,
coder: TranscoderIP6
),
MAProtocol(mcodec: multiCodec("p2p-circuit"), kind: Marker, size: 0),
MAProtocol(mcodec: multiCodec("p2p-websocket-star"), kind: Marker, size: 0),
MAProtocol(mcodec: multiCodec("p2p-webrtc-star"), kind: Marker, size: 0),
MAProtocol(mcodec: multiCodec("p2p-webrtc-direct"), kind: Marker, size: 0),
MAProtocol(
mcodec: multiCodec("dccp"), kind: Fixed, size: 2,
coder: TranscoderPort
),
MAProtocol(
mcodec: multiCodec("sctp"), kind: Fixed, size: 2,
coder: TranscoderPort
),
MAProtocol(
mcodec: multiCodec("udt"), kind: Marker, size: 0
),
MAProtocol(
mcodec: multiCodec("utp"), kind: Marker, size: 0
),
MAProtocol(
mcodec: multiCodec("http"), kind: Marker, size: 0
),
MAProtocol(
mcodec: multiCodec("https"), kind: Marker, size: 0
),
MAProtocol(
mcodec: multiCodec("quic"), kind: Marker, size: 0
),
MAProtocol(
mcodec: multiCodec("ip6zone"), kind: Length, size: 0,
coder: TranscoderIP6Zone
),
MAProtocol(
mcodec: multiCodec("onion"), kind: Fixed, size: 10,
coder: TranscoderOnion
),
MAProtocol(
mcodec: multiCodec("onion3"), kind: Fixed, size: 37,
coder: TranscoderOnion3
),
MAProtocol(
mcodec: multiCodec("ws"), kind: Marker, size: 0
),
MAProtocol(
mcodec: multiCodec("wss"), kind: Marker, size: 0
),
MAProtocol(
mcodec: multiCodec("ipfs"), kind: Length, size: 0,
coder: TranscoderP2P
),
MAProtocol(
mcodec: multiCodec("p2p"), kind: Length, size: 0,
coder: TranscoderP2P
),
MAProtocol(
mcodec: multiCodec("unix"), kind: Path, size: 0,
coder: TranscoderUnix
),
MAProtocol(
mcodec: multiCodec("dns"), kind: Length, size: 0,
coder: TranscoderDNS
),
MAProtocol(
mcodec: multiCodec("dns4"), kind: Length, size: 0,
coder: TranscoderDNS
),
MAProtocol(
mcodec: multiCodec("dns6"), kind: Length, size: 0,
coder: TranscoderDNS
),
MAProtocol(
mcodec: multiCodec("dnsaddr"), kind: Length, size: 0,
coder: TranscoderDNS
),
MAProtocol(
mcodec: multiCodec("p2p-circuit"), kind: Marker, size: 0
),
MAProtocol(
mcodec: multiCodec("p2p-websocket-star"), kind: Marker, size: 0
),
MAProtocol(
mcodec: multiCodec("p2p-webrtc-star"), kind: Marker, size: 0
),
MAProtocol(
mcodec: multiCodec("p2p-webrtc-direct"), kind: Marker, size: 0
)
]
DNSANY* = mapEq("dns")
@@ -401,24 +463,13 @@ const
DNS* = mapOr(DNSANY, DNS4, DNS6, DNSADDR)
IP* = mapOr(IP4, IP6)
DNS_OR_IP* = mapOr(DNS, IP)
TCP_DNS* = mapAnd(DNS, mapEq("tcp"))
TCP_IP* = mapAnd(IP, mapEq("tcp"))
TCP* = mapOr(TCP_DNS, TCP_IP)
UDP_DNS* = mapAnd(DNS, mapEq("udp"))
UDP_IP* = mapAnd(IP, mapEq("udp"))
UDP* = mapOr(UDP_DNS, UDP_IP)
TCP* = mapOr(mapAnd(DNS, mapEq("tcp")), mapAnd(IP, mapEq("tcp")))
UDP* = mapOr(mapAnd(DNS, mapEq("udp")), mapAnd(IP, mapEq("udp")))
UTP* = mapAnd(UDP, mapEq("utp"))
QUIC* = mapAnd(UDP, mapEq("quic"))
UNIX* = mapEq("unix")
WS_DNS* = mapAnd(TCP_DNS, mapEq("ws"))
WS_IP* = mapAnd(TCP_IP, mapEq("ws"))
WS* = mapAnd(TCP, mapEq("ws"))
TLS_WS* = mapOr(mapEq("wss"), mapAnd(mapEq("tls"), mapEq("ws")))
WSS_DNS* = mapAnd(TCP_DNS, TLS_WS)
WSS_IP* = mapAnd(TCP_IP, TLS_WS)
WSS* = mapAnd(TCP, TLS_WS)
WebSockets_DNS* = mapOr(WS_DNS, WSS_DNS)
WebSockets_IP* = mapOr(WS_IP, WSS_IP)
WSS* = mapAnd(TCP, mapEq("wss"))
WebSockets* = mapOr(WS, WSS)
Onion3* = mapEq("onion3")
TcpOnion3* = mapAnd(TCP, Onion3)
@@ -432,24 +483,31 @@ const
IPFS* = mapAnd(Reliable, P2PPattern)
HTTP* = mapOr(
mapAnd(TCP, mapEq("http")), mapAnd(IP, mapEq("http")), mapAnd(DNS, mapEq("http"))
mapAnd(TCP, mapEq("http")),
mapAnd(IP, mapEq("http")),
mapAnd(DNS, mapEq("http"))
)
HTTPS* = mapOr(
mapAnd(TCP, mapEq("https")), mapAnd(IP, mapEq("https")), mapAnd(DNS, mapEq("https"))
mapAnd(TCP, mapEq("https")),
mapAnd(IP, mapEq("https")),
mapAnd(DNS, mapEq("https"))
)
WebRTCDirect* = mapOr(
mapAnd(HTTP, mapEq("p2p-webrtc-direct")), mapAnd(HTTPS, mapEq("p2p-webrtc-direct"))
mapAnd(HTTP, mapEq("p2p-webrtc-direct")),
mapAnd(HTTPS, mapEq("p2p-webrtc-direct"))
)
CircuitRelay* = mapEq("p2p-circuit")
proc initMultiAddressCodeTable(): Table[MultiCodec, MAProtocol] {.compileTime.} =
proc initMultiAddressCodeTable(): Table[MultiCodec,
MAProtocol] {.compileTime.} =
for item in ProtocolsList:
result[item.mcodec] = item
const CodeAddresses = initMultiAddressCodeTable()
const
CodeAddresses = initMultiAddressCodeTable()
proc trimRight(s: string, ch: char): string =
## Consume trailing characters ``ch`` from string ``s`` and return result.
@@ -459,7 +517,7 @@ proc trimRight(s: string, ch: char): string =
inc(m)
else:
break
result = s[0 .. (s.high - m)]
result = s[0..(s.high - m)]
proc protoCode*(ma: MultiAddress): MaResult[MultiCodec] =
## Returns MultiAddress ``ma`` protocol code.
@@ -487,7 +545,8 @@ proc protoName*(ma: MultiAddress): MaResult[string] =
else:
ok($(proto.mcodec))
proc protoArgument*(ma: MultiAddress, value: var openArray[byte]): MaResult[int] =
proc protoArgument*(ma: MultiAddress,
value: var openArray[byte]): MaResult[int] =
## Returns MultiAddress ``ma`` protocol argument value.
##
## If the current MultiAddress does not have an argument value, then the result will be
@@ -506,7 +565,7 @@ proc protoArgument*(ma: MultiAddress, value: var openArray[byte]): MaResult[int]
if proto.kind == Fixed:
res = proto.size
if len(value) >= res and
vb.data.readArray(value.toOpenArray(0, proto.size - 1)) != proto.size:
vb.data.readArray(value.toOpenArray(0, proto.size - 1)) != proto.size:
err("multiaddress: Decoding protocol error")
else:
ok(res)
@@ -527,7 +586,7 @@ proc protoAddress*(ma: MultiAddress): MaResult[seq[byte]] =
## If the current MultiAddress does not have an argument value, then the result array
## will be empty.
var buffer = newSeq[byte](len(ma.data.buffer))
let res = ?protoArgument(ma, buffer)
let res = ? protoArgument(ma, buffer)
buffer.setLen(res)
ok(buffer)
@@ -546,8 +605,7 @@ proc getPart(ma: MultiAddress, index: int): MaResult[MultiAddress] =
var res: MultiAddress
res.data = initVBuffer()
if index < 0:
return err("multiaddress: negative index gived to getPart")
if index < 0: return err("multiaddress: negative index gived to getPart")
while offset <= index:
if vb.data.readVarint(header) == -1:
@@ -556,6 +614,7 @@ proc getPart(ma: MultiAddress, index: int): MaResult[MultiAddress] =
let proto = CodeAddresses.getOrDefault(MultiCodec(header))
if proto.kind == None:
return err("multiaddress: Unsupported protocol '" & $header & "'")
elif proto.kind == Fixed:
data.setLen(proto.size)
if vb.data.readArray(data) != proto.size:
@@ -582,27 +641,22 @@ proc getPart(ma: MultiAddress, index: int): MaResult[MultiAddress] =
proc getParts[U, V](ma: MultiAddress, slice: HSlice[U, V]): MaResult[MultiAddress] =
when slice.a is BackwardsIndex or slice.b is BackwardsIndex:
let maLength = ?len(ma)
let maLength = ? len(ma)
template normalizeIndex(index): int =
when index is BackwardsIndex:
maLength - int(index)
else:
int(index)
when index is BackwardsIndex: maLength - int(index)
else: int(index)
let
indexStart = normalizeIndex(slice.a)
indexEnd = normalizeIndex(slice.b)
var res: MultiAddress
for i in indexStart .. indexEnd:
?res.append(?ma[i])
for i in indexStart..indexEnd:
? res.append(? ma[i])
ok(res)
proc `[]`*(
ma: MultiAddress, i: int | BackwardsIndex
): MaResult[MultiAddress] {.inline.} =
proc `[]`*(ma: MultiAddress, i: int | BackwardsIndex): MaResult[MultiAddress] {.inline.} =
## Returns part with index ``i`` of MultiAddress ``ma``.
when i is BackwardsIndex:
let maLength = ?len(ma)
let maLength = ? len(ma)
ma.getPart(maLength - int(i))
else:
ma.getPart(i)
@@ -626,7 +680,9 @@ iterator items*(ma: MultiAddress): MaResult[MultiAddress] =
let proto = CodeAddresses.getOrDefault(MultiCodec(header))
if proto.kind == None:
yield err(MaResult[MultiAddress], "Unsupported protocol '" & $header & "'")
yield err(MaResult[MultiAddress], "Unsupported protocol '" &
$header & "'")
elif proto.kind == Fixed:
data.setLen(proto.size)
if vb.data.readArray(data) != proto.size:
@@ -648,8 +704,7 @@ iterator items*(ma: MultiAddress): MaResult[MultiAddress] =
proc len*(ma: MultiAddress): MaResult[int] =
var counter: int
for part in ma:
if part.isErr:
return err(part.error)
if part.isErr: return err(part.error)
counter.inc()
ok(counter)
@@ -662,7 +717,8 @@ proc contains*(ma: MultiAddress, codec: MultiCodec): MaResult[bool] {.inline.} =
return ok(true)
ok(false)
proc `[]`*(ma: MultiAddress, codec: MultiCodec): MaResult[MultiAddress] {.inline.} =
proc `[]`*(ma: MultiAddress,
codec: MultiCodec): MaResult[MultiAddress] {.inline.} =
## Returns the partial MultiAddress with MultiCodec ``codec`` present in
## MultiAddress ``ma``.
for item in ma.items:
@@ -687,12 +743,13 @@ proc toString*(value: MultiAddress): MaResult[string] =
return err("multiaddress: Unsupported protocol '" & $header & "'")
if proto.kind in {Fixed, Length, Path}:
if isNil(proto.coder.bufferToString):
return err("multiaddress: Missing protocol '" & $(proto.mcodec) & "' coder")
return err("multiaddress: Missing protocol '" & $(proto.mcodec) &
"' coder")
if not proto.coder.bufferToString(vb.data, part):
return err("multiaddress: Decoding protocol error")
parts.add($(proto.mcodec))
if len(part) > 0 and (proto.kind == Path) and (part[0] == '/'):
parts.add(part[1 ..^ 1])
if proto.kind == Path and part[0] == '/':
parts.add(part[1..^1])
else:
parts.add(part)
elif proto.kind == Marker:
@@ -701,13 +758,11 @@ proc toString*(value: MultiAddress): MaResult[string] =
res = "/" & parts.join("/")
ok(res)
proc `$`*(value: MultiAddress): string =
proc `$`*(value: MultiAddress): string {.raises: [Defect].} =
## Return string representation of MultiAddress ``value``.
let s = value.toString()
if s.isErr:
s.error
else:
s[]
if s.isErr: s.error
else: s[]
proc protocols*(value: MultiAddress): MaResult[seq[MultiCodec]] =
## Returns list of protocol codecs inside of MultiAddress ``value``.
@@ -724,9 +779,8 @@ proc write*(vb: var VBuffer, ma: MultiAddress) {.inline.} =
## Write MultiAddress value ``ma`` to buffer ``vb``.
vb.writeArray(ma.data.buffer)
proc encode*(
mbtype: typedesc[MultiBase], encoding: string, ma: MultiAddress
): string {.inline.} =
proc encode*(mbtype: typedesc[MultiBase], encoding: string,
ma: MultiAddress): string {.inline.} =
## Get MultiBase encoded representation of ``ma`` using encoding
## ``encoding``.
result = MultiBase.encode(encoding, ma.data.buffer)
@@ -753,8 +807,8 @@ proc validate*(ma: MultiAddress): bool =
result = true
proc init*(
mtype: typedesc[MultiAddress], protocol: MultiCodec, value: openArray[byte] = []
): MaResult[MultiAddress] =
mtype: typedesc[MultiAddress], protocol: MultiCodec,
value: openArray[byte] = []): MaResult[MultiAddress] =
## Initialize MultiAddress object from protocol id ``protocol`` and array
## of bytes ``value``.
let proto = CodeAddresses.getOrDefault(protocol)
@@ -784,21 +838,19 @@ proc init*(
of None:
raiseAssert "None checked above"
proc init*(
mtype: typedesc[MultiAddress], protocol: MultiCodec, value: PeerId
): MaResult[MultiAddress] {.inline.} =
proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec,
value: PeerId): MaResult[MultiAddress] {.inline.} =
## Initialize MultiAddress object from protocol id ``protocol`` and peer id
## ``value``.
init(mtype, protocol, value.data)
proc init*(
mtype: typedesc[MultiAddress], protocol: MultiCodec, value: int
): MaResult[MultiAddress] =
proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec,
value: int): MaResult[MultiAddress] =
## Initialize MultiAddress object from protocol id ``protocol`` and integer
## ``value``. This procedure can be used to instantiate ``tcp``, ``udp``,
## ``dccp`` and ``sctp`` MultiAddresses.
var allowed =
[multiCodec("tcp"), multiCodec("udp"), multiCodec("dccp"), multiCodec("sctp")]
var allowed = [multiCodec("tcp"), multiCodec("udp"), multiCodec("dccp"),
multiCodec("sctp")]
if protocol notin allowed:
err("multiaddress: Incorrect protocol for integer value")
else:
@@ -818,10 +870,9 @@ proc getProtocol(name: string): MAProtocol {.inline.} =
if mc != InvalidMultiCodec:
result = CodeAddresses.getOrDefault(mc)
proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress] =
proc init*(mtype: typedesc[MultiAddress],
value: string): MaResult[MultiAddress] =
## Initialize MultiAddress object from string representation ``value``.
if len(value) == 0 or value == "/":
return err("multiaddress: Address must not be empty!")
var parts = value.trimRight('/').split('/')
if len(parts[0]) != 0:
err("multiaddress: Invalid MultiAddress, must start with `/`")
@@ -837,7 +888,8 @@ proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress]
else:
if proto.kind in {Fixed, Length, Path}:
if isNil(proto.coder.stringToBuffer):
return err("multiaddress: Missing protocol '" & part & "' transcoder")
return err("multiaddress: Missing protocol '" &
part & "' transcoder")
if offset + 1 >= len(parts):
return err("multiaddress: Missing protocol '" & part & "' argument")
@@ -846,15 +898,16 @@ proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress]
res.data.write(proto.mcodec)
let res = proto.coder.stringToBuffer(parts[offset + 1], res.data)
if not res:
return err(
"multiaddress: Error encoding `" & part & "/" & parts[offset + 1] & "`"
)
return err("multiaddress: Error encoding `" & part & "/" &
parts[offset + 1] & "`")
offset += 2
elif proto.kind == Path:
var path = "/" & (parts[(offset + 1) ..^ 1].join("/"))
var path = "/" & (parts[(offset + 1)..^1].join("/"))
res.data.write(proto.mcodec)
if not proto.coder.stringToBuffer(path, res.data):
return err("multiaddress: Error encoding `" & part & "/" & path & "`")
return err("multiaddress: Error encoding `" & part & "/" &
path & "`")
break
elif proto.kind == Marker:
@@ -863,12 +916,11 @@ proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress]
res.data.finish()
ok(res)
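A parse/render round-trip sketch for the string initializer above; tryGet is the usual Result helper and `$` is defined later in this file:

let ma = MultiAddress.init("/ip4/127.0.0.1/tcp/4001").tryGet()
doAssert $ma == "/ip4/127.0.0.1/tcp/4001"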
proc init*(
mtype: typedesc[MultiAddress], data: openArray[byte]
): MaResult[MultiAddress] =
proc init*(mtype: typedesc[MultiAddress],
data: openArray[byte]): MaResult[MultiAddress] =
## Initialize MultiAddress with array of bytes ``data``.
if len(data) == 0:
err("multiaddress: Address must not be empty!")
err("multiaddress: Address could not be empty!")
else:
var res: MultiAddress
res.data = initVBuffer()
@@ -883,55 +935,38 @@ proc init*(mtype: typedesc[MultiAddress]): MultiAddress =
## Initialize empty MultiAddress.
result.data = initVBuffer()
proc init*(
mtype: typedesc[MultiAddress],
address: IpAddress,
protocol: IpTransportProtocol,
port: Port,
): MultiAddress =
proc init*(mtype: typedesc[MultiAddress], address: ValidIpAddress,
protocol: IpTransportProtocol, port: Port): MultiAddress =
var res: MultiAddress
res.data = initVBuffer()
let
networkProto =
case address.family
of IpAddressFamily.IPv4:
getProtocol("ip4")
of IpAddressFamily.IPv6:
getProtocol("ip6")
networkProto = case address.family
of IpAddressFamily.IPv4: getProtocol("ip4")
of IpAddressFamily.IPv6: getProtocol("ip6")
transportProto =
case protocol
of tcpProtocol:
getProtocol("tcp")
of udpProtocol:
getProtocol("udp")
transportProto = case protocol
of tcpProtocol: getProtocol("tcp")
of udpProtocol: getProtocol("udp")
res.data.write(networkProto.mcodec)
case address.family
of IpAddressFamily.IPv4:
res.data.writeArray(address.address_v4)
of IpAddressFamily.IPv6:
res.data.writeArray(address.address_v6)
of IpAddressFamily.IPv4: res.data.writeArray(address.address_v4)
of IpAddressFamily.IPv6: res.data.writeArray(address.address_v6)
res.data.write(transportProto.mcodec)
res.data.writeArray(toBytesBE(uint16(port)))
res.data.finish()
res
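The same address built from a parsed IP, against the IpAddress flavour of the signature above (the other side of this diff takes a stew ValidIpAddress instead; parseIpAddress is from std/net):

import std/net

let ma = MultiAddress.init(parseIpAddress("127.0.0.1"), tcpProtocol, Port(4001))
doAssert $ma == "/ip4/127.0.0.1/tcp/4001"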
proc init*(
mtype: typedesc[MultiAddress], address: TransportAddress, protocol = IPPROTO_TCP
): MaResult[MultiAddress] =
proc init*(mtype: typedesc[MultiAddress], address: TransportAddress,
protocol = IPPROTO_TCP): MaResult[MultiAddress] =
## Initialize MultiAddress using chronos.TransportAddress (IPv4/IPv6/Unix)
## and protocol information (UDP/TCP).
var res: MultiAddress
res.data = initVBuffer()
let protoProto =
case protocol
of IPPROTO_TCP:
getProtocol("tcp")
of IPPROTO_UDP:
getProtocol("udp")
else:
default(MAProtocol)
let protoProto = case protocol
of IPPROTO_TCP: getProtocol("tcp")
of IPPROTO_UDP: getProtocol("udp")
else: default(MAProtocol)
if protoProto.size == 0:
return err("multiaddress: protocol should be either TCP or UDP")
if address.family == AddressFamily.IPv4:
@@ -970,7 +1005,8 @@ proc append*(m1: var MultiAddress, m2: MultiAddress): MaResult[void] =
else:
ok()
proc `&`*(m1, m2: MultiAddress): MultiAddress {.raises: [LPError].} =
proc `&`*(m1, m2: MultiAddress): MultiAddress {.
raises: [Defect, LPError].} =
## Concatenates two addresses ``m1`` and ``m2``, and returns result.
##
## This procedure performs validation of concatenated result and can raise
@@ -979,7 +1015,8 @@ proc `&`*(m1, m2: MultiAddress): MultiAddress {.raises: [LPError].} =
concat(m1, m2).tryGet()
proc `&=`*(m1: var MultiAddress, m2: MultiAddress) {.raises: [LPError].} =
proc `&=`*(m1: var MultiAddress, m2: MultiAddress) {.
raises: [Defect, LPError].} =
## Concatenates two addresses ``m1`` and ``m2``.
##
## This procedure performs validation of concatenated result and can raise
@@ -1001,12 +1038,13 @@ proc matchPart(pat: MaPattern, protos: seq[MultiCodec]): MaPatResult =
let res = a.matchPart(pcs)
if res.flag:
#Greedy Or
if result.flag == false or result.rem.len > res.rem.len:
if result.flag == false or
result.rem.len > res.rem.len:
result = res
elif pat.operator == And:
if len(pcs) < len(pat.args):
return MaPatResult(flag: false, rem: empty)
for i in 0 ..< len(pat.args):
for i in 0..<len(pat.args):
let res = pat.args[i].matchPart(pcs)
if not res.flag:
return MaPatResult(flag: false, rem: res.rem)
@@ -1016,23 +1054,25 @@ proc matchPart(pat: MaPattern, protos: seq[MultiCodec]): MaPatResult =
if len(pcs) == 0:
return MaPatResult(flag: false, rem: empty)
if pcs[0] == pat.value:
return MaPatResult(flag: true, rem: pcs[1 ..^ 1])
return MaPatResult(flag: true, rem: pcs[1..^1])
result = MaPatResult(flag: false, rem: empty)
proc match*(pat: MaPattern, address: MultiAddress): bool =
## Match full ``address`` using pattern ``pat`` and return ``true`` if
## ``address`` satisfies pattern.
let protos = address.protocols().valueOr:
let protos = address.protocols()
if protos.isErr():
return false
let res = matchPart(pat, protos)
let res = matchPart(pat, protos.get())
res.flag and (len(res.rem) == 0)
proc matchPartial*(pat: MaPattern, address: MultiAddress): bool =
## Match prefix part of ``address`` using pattern ``pat`` and return
## ``true`` if ``address`` starts with pattern.
let protos = address.protocols().valueOr:
let protos = address.protocols()
if protos.isErr():
return false
let res = matchPart(pat, protos)
let res = matchPart(pat, protos.get())
res.flag
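A quick sketch of the combinators in action, using the TCP/IP/UDP patterns declared in the const section above:

let ma = MultiAddress.init("/ip4/127.0.0.1/tcp/4001").tryGet()
doAssert TCP.match(ma)        # consumes the whole address
doAssert IP.matchPartial(ma)  # prefix only: the leading ip4 part
doAssert not UDP.match(ma)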
proc `$`*(pat: MaPattern): string =
@@ -1053,41 +1093,35 @@ proc bytes*(value: MultiAddress): seq[byte] =
proc write*(pb: var ProtoBuffer, field: int, value: MultiAddress) {.inline.} =
write(pb, field, value.data.buffer)
proc getField*(
pb: ProtoBuffer, field: int, value: var MultiAddress
): ProtoResult[bool] {.inline.} =
proc getField*(pb: ProtoBuffer, field: int,
value: var MultiAddress): ProtoResult[bool] {.
inline.} =
var buffer: seq[byte]
let res = ?pb.getField(field, buffer)
if not (res):
let res = ? pb.getField(field, buffer)
if not(res):
ok(false)
else:
value = MultiAddress.init(buffer).valueOr:
return err(ProtoError.IncorrectBlob)
ok(true)
let ma = MultiAddress.init(buffer)
if ma.isOk():
value = ma.get()
ok(true)
else:
err(ProtoError.IncorrectBlob)
proc getRepeatedField*(
pb: ProtoBuffer, field: int, value: var seq[MultiAddress]
): ProtoResult[bool] {.inline.} =
## Read repeated field from protobuf message. ``field`` is field number.
## If the message is malformed, an error is returned. If field is not present
## in message, then ``ok(false)`` is returned and value is empty. If field is
## present, but no items could be parsed, then
## ``err(ProtoError.IncorrectBlob)`` is returned and value is empty.
## If field is present and some item could be parsed, then ``true`` is
## returned and value contains the parsed values.
proc getRepeatedField*(pb: ProtoBuffer, field: int,
value: var seq[MultiAddress]): ProtoResult[bool] {.
inline.} =
var items: seq[seq[byte]]
value.setLen(0)
let res = ?pb.getRepeatedField(field, items)
if not (res):
let res = ? pb.getRepeatedField(field, items)
if not(res):
ok(false)
else:
for item in items:
let ma = MultiAddress.init(item).valueOr:
debug "Unsupported MultiAddress in blob", ma = item
continue
value.add(ma)
if value.len == 0:
err(ProtoError.IncorrectBlob)
else:
ok(true)
let ma = MultiAddress.init(item)
if ma.isOk():
value.add(ma.get())
else:
value.setLen(0)
return err(ProtoError.IncorrectBlob)
ok(true)
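A round-trip sketch for these protobuf helpers, assuming minprotobuf's initProtoBuffer/finish API and its public buffer field; field number 1 is arbitrary:

var pb = initProtoBuffer()
pb.write(1, MultiAddress.init("/ip4/127.0.0.1/tcp/4001").tryGet())
pb.finish()
var decoded: MultiAddress
doAssert initProtoBuffer(pb.buffer).getField(1, decoded).tryGet()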
View File
@@ -13,39 +13,36 @@
## 1. base32z
##
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import tables
import stew/[base32, base58, base64, results]
type
MultiBaseStatus* {.pure.} = enum
Error
Success
Overrun
Incorrect
BadCodec
NotSupported
Error, Success, Overrun, Incorrect, BadCodec, NotSupported
MultiBase* = object
MBCodeSize = proc(length: int): int {.nimcall, gcsafe, noSideEffect, raises: [].}
MBCodeSize = proc(length: int): int {.nimcall, gcsafe, noSideEffect, raises: [Defect].}
MBCodec = object
code: char
name: string
encr: proc(
inbytes: openArray[byte], outbytes: var openArray[char], outlen: var int
): MultiBaseStatus {.nimcall, gcsafe, noSideEffect, raises: [].}
decr: proc(
inbytes: openArray[char], outbytes: var openArray[byte], outlen: var int
): MultiBaseStatus {.nimcall, gcsafe, noSideEffect, raises: [].}
encr: proc(inbytes: openArray[byte],
outbytes: var openArray[char],
outlen: var int): MultiBaseStatus {.nimcall, gcsafe, noSideEffect, raises: [Defect].}
decr: proc(inbytes: openArray[char],
outbytes: var openArray[byte],
outlen: var int): MultiBaseStatus {.nimcall, gcsafe, noSideEffect, raises: [Defect].}
encl: MBCodeSize
decl: MBCodeSize
proc idd(
inbytes: openArray[char], outbytes: var openArray[byte], outlen: var int
): MultiBaseStatus =
proc idd(inbytes: openArray[char], outbytes: var openArray[byte],
outlen: var int): MultiBaseStatus =
let length = len(inbytes)
if length > len(outbytes):
outlen = length
@@ -55,9 +52,9 @@ proc idd(
outlen = length
result = MultiBaseStatus.Success
proc ide(
inbytes: openArray[byte], outbytes: var openArray[char], outlen: var int
): MultiBaseStatus =
proc ide(inbytes: openArray[byte],
outbytes: var openArray[char],
outlen: var int): MultiBaseStatus =
let length = len(inbytes)
if length > len(outbytes):
outlen = length
@@ -67,37 +64,31 @@ proc ide(
outlen = length
result = MultiBaseStatus.Success
proc idel(length: int): int =
length
proc idel(length: int): int = length
proc iddl(length: int): int = length
proc iddl(length: int): int =
length
proc b16d(
inbytes: openArray[char], outbytes: var openArray[byte], outlen: var int
): MultiBaseStatus =
proc b16d(inbytes: openArray[char],
outbytes: var openArray[byte],
outlen: var int): MultiBaseStatus =
discard
proc b16e(
inbytes: openArray[byte], outbytes: var openArray[char], outlen: var int
): MultiBaseStatus =
proc b16e(inbytes: openArray[byte],
outbytes: var openArray[char],
outlen: var int): MultiBaseStatus =
discard
proc b16ud(
inbytes: openArray[char], outbytes: var openArray[byte], outlen: var int
): MultiBaseStatus =
proc b16ud(inbytes: openArray[char],
outbytes: var openArray[byte],
outlen: var int): MultiBaseStatus =
discard
proc b16ue(
inbytes: openArray[byte], outbytes: var openArray[char], outlen: var int
): MultiBaseStatus =
proc b16ue(inbytes: openArray[byte],
outbytes: var openArray[char],
outlen: var int): MultiBaseStatus =
discard
proc b16el(length: int): int =
length shl 1
proc b16dl(length: int): int =
(length + 1) div 2
proc b16el(length: int): int = length shl 1
proc b16dl(length: int): int = (length + 1) div 2
proc b32ce(r: Base32Status): MultiBaseStatus {.inline.} =
result = MultiBaseStatus.Error
@@ -126,253 +117,218 @@ proc b64ce(r: Base64Status): MultiBaseStatus {.inline.} =
elif r == Base64Status.Success:
result = MultiBaseStatus.Success
proc b32hd(
inbytes: openArray[char], outbytes: var openArray[byte], outlen: var int
): MultiBaseStatus =
proc b32hd(inbytes: openArray[char],
outbytes: var openArray[byte],
outlen: var int): MultiBaseStatus =
result = b32ce(HexBase32Lower.decode(inbytes, outbytes, outlen))
proc b32he(
inbytes: openArray[byte], outbytes: var openArray[char], outlen: var int
): MultiBaseStatus =
proc b32he(inbytes: openArray[byte],
outbytes: var openArray[char],
outlen: var int): MultiBaseStatus =
result = b32ce(HexBase32Lower.encode(inbytes, outbytes, outlen))
proc b32hud(
inbytes: openArray[char], outbytes: var openArray[byte], outlen: var int
): MultiBaseStatus =
proc b32hud(inbytes: openArray[char],
outbytes: var openArray[byte],
outlen: var int): MultiBaseStatus =
result = b32ce(HexBase32Upper.decode(inbytes, outbytes, outlen))
proc b32hue(
inbytes: openArray[byte], outbytes: var openArray[char], outlen: var int
): MultiBaseStatus =
proc b32hue(inbytes: openArray[byte],
outbytes: var openArray[char],
outlen: var int): MultiBaseStatus =
result = b32ce(HexBase32Upper.encode(inbytes, outbytes, outlen))
proc b32hpd(
inbytes: openArray[char], outbytes: var openArray[byte], outlen: var int
): MultiBaseStatus =
proc b32hpd(inbytes: openArray[char],
outbytes: var openArray[byte],
outlen: var int): MultiBaseStatus =
result = b32ce(HexBase32LowerPad.decode(inbytes, outbytes, outlen))
proc b32hpe(
inbytes: openArray[byte], outbytes: var openArray[char], outlen: var int
): MultiBaseStatus =
proc b32hpe(inbytes: openArray[byte],
outbytes: var openArray[char],
outlen: var int): MultiBaseStatus =
result = b32ce(HexBase32LowerPad.encode(inbytes, outbytes, outlen))
proc b32hpud(
inbytes: openArray[char], outbytes: var openArray[byte], outlen: var int
): MultiBaseStatus =
proc b32hpud(inbytes: openArray[char],
outbytes: var openArray[byte],
outlen: var int): MultiBaseStatus =
result = b32ce(HexBase32UpperPad.decode(inbytes, outbytes, outlen))
proc b32hpue(
inbytes: openArray[byte], outbytes: var openArray[char], outlen: var int
): MultiBaseStatus =
proc b32hpue(inbytes: openArray[byte],
outbytes: var openArray[char],
outlen: var int): MultiBaseStatus =
result = b32ce(HexBase32UpperPad.encode(inbytes, outbytes, outlen))
proc b32d(
inbytes: openArray[char], outbytes: var openArray[byte], outlen: var int
): MultiBaseStatus =
proc b32d(inbytes: openArray[char],
outbytes: var openArray[byte],
outlen: var int): MultiBaseStatus =
result = b32ce(Base32Lower.decode(inbytes, outbytes, outlen))
proc b32e(
inbytes: openArray[byte], outbytes: var openArray[char], outlen: var int
): MultiBaseStatus =
proc b32e(inbytes: openArray[byte],
outbytes: var openArray[char],
outlen: var int): MultiBaseStatus =
result = b32ce(Base32Lower.encode(inbytes, outbytes, outlen))
proc b32ud(
inbytes: openArray[char], outbytes: var openArray[byte], outlen: var int
): MultiBaseStatus =
proc b32ud(inbytes: openArray[char],
outbytes: var openArray[byte],
outlen: var int): MultiBaseStatus =
result = b32ce(Base32Upper.decode(inbytes, outbytes, outlen))
proc b32ue(
inbytes: openArray[byte], outbytes: var openArray[char], outlen: var int
): MultiBaseStatus =
proc b32ue(inbytes: openArray[byte],
outbytes: var openArray[char],
outlen: var int): MultiBaseStatus =
result = b32ce(Base32Upper.encode(inbytes, outbytes, outlen))
proc b32pd(
inbytes: openArray[char], outbytes: var openArray[byte], outlen: var int
): MultiBaseStatus =
proc b32pd(inbytes: openArray[char],
outbytes: var openArray[byte],
outlen: var int): MultiBaseStatus =
result = b32ce(Base32LowerPad.decode(inbytes, outbytes, outlen))
proc b32pe(
inbytes: openArray[byte], outbytes: var openArray[char], outlen: var int
): MultiBaseStatus =
proc b32pe(inbytes: openArray[byte],
outbytes: var openArray[char],
outlen: var int): MultiBaseStatus =
result = b32ce(Base32LowerPad.encode(inbytes, outbytes, outlen))
proc b32pud(
inbytes: openArray[char], outbytes: var openArray[byte], outlen: var int
): MultiBaseStatus =
proc b32pud(inbytes: openArray[char],
outbytes: var openArray[byte],
outlen: var int): MultiBaseStatus =
result = b32ce(Base32UpperPad.decode(inbytes, outbytes, outlen))
proc b32pue(
inbytes: openArray[byte], outbytes: var openArray[char], outlen: var int
): MultiBaseStatus =
proc b32pue(inbytes: openArray[byte],
outbytes: var openArray[char],
outlen: var int): MultiBaseStatus =
result = b32ce(Base32UpperPad.encode(inbytes, outbytes, outlen))
proc b32el(length: int): int =
Base32Lower.encodedLength(length)
proc b32el(length: int): int = Base32Lower.encodedLength(length)
proc b32dl(length: int): int = Base32Lower.decodedLength(length)
proc b32pel(length: int): int = Base32LowerPad.encodedLength(length)
proc b32pdl(length: int): int = Base32LowerPad.decodedLength(length)
proc b32dl(length: int): int =
Base32Lower.decodedLength(length)
proc b32pel(length: int): int =
Base32LowerPad.encodedLength(length)
proc b32pdl(length: int): int =
Base32LowerPad.decodedLength(length)
proc b58fd(
inbytes: openArray[char], outbytes: var openArray[byte], outlen: var int
): MultiBaseStatus =
proc b58fd(inbytes: openArray[char],
outbytes: var openArray[byte],
outlen: var int): MultiBaseStatus =
result = b58ce(FLCBase58.decode(inbytes, outbytes, outlen))
proc b58fe(
inbytes: openArray[byte], outbytes: var openArray[char], outlen: var int
): MultiBaseStatus =
proc b58fe(inbytes: openArray[byte],
outbytes: var openArray[char],
outlen: var int): MultiBaseStatus =
result = b58ce(FLCBase58.encode(inbytes, outbytes, outlen))
proc b58bd(
inbytes: openArray[char], outbytes: var openArray[byte], outlen: var int
): MultiBaseStatus =
proc b58bd(inbytes: openArray[char],
outbytes: var openArray[byte],
outlen: var int): MultiBaseStatus =
result = b58ce(BTCBase58.decode(inbytes, outbytes, outlen))
proc b58be(
inbytes: openArray[byte], outbytes: var openArray[char], outlen: var int
): MultiBaseStatus =
proc b58be(inbytes: openArray[byte],
outbytes: var openArray[char],
outlen: var int): MultiBaseStatus =
result = b58ce(BTCBase58.encode(inbytes, outbytes, outlen))
proc b58el(length: int): int =
Base58.encodedLength(length)
proc b58el(length: int): int = Base58.encodedLength(length)
proc b58dl(length: int): int = Base58.decodedLength(length)
proc b58dl(length: int): int =
Base58.decodedLength(length)
proc b64el(length: int): int = Base64.encodedLength(length)
proc b64dl(length: int): int = Base64.decodedLength(length)
proc b64pel(length: int): int = Base64Pad.encodedLength(length)
proc b64pdl(length: int): int = Base64Pad.decodedLength(length)
proc b64el(length: int): int =
Base64.encodedLength(length)
proc b64dl(length: int): int =
Base64.decodedLength(length)
proc b64pel(length: int): int =
Base64Pad.encodedLength(length)
proc b64pdl(length: int): int =
Base64Pad.decodedLength(length)
proc b64e(
inbytes: openArray[byte], outbytes: var openArray[char], outlen: var int
): MultiBaseStatus =
proc b64e(inbytes: openArray[byte],
outbytes: var openArray[char],
outlen: var int): MultiBaseStatus =
result = b64ce(Base64.encode(inbytes, outbytes, outlen))
proc b64d(
inbytes: openArray[char], outbytes: var openArray[byte], outlen: var int
): MultiBaseStatus =
proc b64d(inbytes: openArray[char],
outbytes: var openArray[byte],
outlen: var int): MultiBaseStatus =
result = b64ce(Base64.decode(inbytes, outbytes, outlen))
proc b64pe(
inbytes: openArray[byte], outbytes: var openArray[char], outlen: var int
): MultiBaseStatus =
proc b64pe(inbytes: openArray[byte],
outbytes: var openArray[char],
outlen: var int): MultiBaseStatus =
result = b64ce(Base64Pad.encode(inbytes, outbytes, outlen))
proc b64pd(
inbytes: openArray[char], outbytes: var openArray[byte], outlen: var int
): MultiBaseStatus =
proc b64pd(inbytes: openArray[char],
outbytes: var openArray[byte],
outlen: var int): MultiBaseStatus =
result = b64ce(Base64Pad.decode(inbytes, outbytes, outlen))
proc b64ue(
inbytes: openArray[byte], outbytes: var openArray[char], outlen: var int
): MultiBaseStatus =
proc b64ue(inbytes: openArray[byte],
outbytes: var openArray[char],
outlen: var int): MultiBaseStatus =
result = b64ce(Base64Url.encode(inbytes, outbytes, outlen))
proc b64ud(
inbytes: openArray[char], outbytes: var openArray[byte], outlen: var int
): MultiBaseStatus =
proc b64ud(inbytes: openArray[char],
outbytes: var openArray[byte],
outlen: var int): MultiBaseStatus =
result = b64ce(Base64Url.decode(inbytes, outbytes, outlen))
proc b64upe(
inbytes: openArray[byte], outbytes: var openArray[char], outlen: var int
): MultiBaseStatus =
proc b64upe(inbytes: openArray[byte],
outbytes: var openArray[char],
outlen: var int): MultiBaseStatus =
result = b64ce(Base64UrlPad.encode(inbytes, outbytes, outlen))
proc b64upd(
inbytes: openArray[char], outbytes: var openArray[byte], outlen: var int
): MultiBaseStatus =
proc b64upd(inbytes: openArray[char],
outbytes: var openArray[byte],
outlen: var int): MultiBaseStatus =
result = b64ce(Base64UrlPad.decode(inbytes, outbytes, outlen))
const MultiBaseCodecs = [
MBCodec(
name: "identity", code: chr(0x00), decr: idd, encr: ide, decl: iddl, encl: idel
),
MBCodec(name: "base1", code: '1'),
MBCodec(name: "base2", code: '0'),
MBCodec(name: "base8", code: '7'),
MBCodec(name: "base10", code: '9'),
MBCodec(name: "base16", code: 'f', decr: b16d, encr: b16e, decl: b16dl, encl: b16el),
MBCodec(
name: "base16upper", code: 'F', decr: b16ud, encr: b16ue, decl: b16dl, encl: b16el
),
MBCodec(
name: "base32hex", code: 'v', decr: b32hd, encr: b32he, decl: b32dl, encl: b32el
),
MBCodec(
name: "base32hexupper",
code: 'V',
decr: b32hud,
encr: b32hue,
decl: b32dl,
encl: b32el,
),
MBCodec(
name: "base32hexpad",
code: 't',
decr: b32hpd,
encr: b32hpe,
decl: b32pdl,
encl: b32pel,
),
MBCodec(
name: "base32hexpadupper",
code: 'T',
decr: b32hpud,
encr: b32hpue,
decl: b32pdl,
encl: b32pel,
),
MBCodec(name: "base32", code: 'b', decr: b32d, encr: b32e, decl: b32dl, encl: b32el),
MBCodec(
name: "base32upper", code: 'B', decr: b32ud, encr: b32ue, decl: b32dl, encl: b32el
),
MBCodec(
name: "base32pad", code: 'c', decr: b32pd, encr: b32pe, decl: b32pdl, encl: b32pel
),
MBCodec(
name: "base32padupper",
code: 'C',
decr: b32pud,
encr: b32pue,
decl: b32pdl,
encl: b32pel,
),
MBCodec(name: "base32z", code: 'h'),
MBCodec(
name: "base58flickr", code: 'Z', decr: b58fd, encr: b58fe, decl: b58dl, encl: b58el
),
MBCodec(
name: "base58btc", code: 'z', decr: b58bd, encr: b58be, decl: b58dl, encl: b58el
),
MBCodec(name: "base64", code: 'm', decr: b64d, encr: b64e, decl: b64dl, encl: b64el),
MBCodec(
name: "base64pad", code: 'M', decr: b64pd, encr: b64pe, decl: b64pdl, encl: b64pel
),
MBCodec(
name: "base64url", code: 'u', decr: b64ud, encr: b64ue, decl: b64dl, encl: b64el
),
MBCodec(
name: "base64urlpad",
code: 'U',
decr: b64upd,
encr: b64upe,
decl: b64pdl,
encl: b64pel,
),
]
const
MultiBaseCodecs = [
MBCodec(name: "identity", code: chr(0x00),
decr: idd, encr: ide, decl: iddl, encl: idel
),
MBCodec(name: "base1", code: '1'),
MBCodec(name: "base2", code: '0'),
MBCodec(name: "base8", code: '7'),
MBCodec(name: "base10", code: '9'),
MBCodec(name: "base16", code: 'f',
decr: b16d, encr: b16e, decl: b16dl, encl: b16el
),
MBCodec(name: "base16upper", code: 'F',
decr: b16ud, encr: b16ue, decl: b16dl, encl: b16el
),
MBCodec(name: "base32hex", code: 'v',
decr: b32hd, encr: b32he, decl: b32dl, encl: b32el
),
MBCodec(name: "base32hexupper", code: 'V',
decr: b32hud, encr: b32hue, decl: b32dl, encl: b32el
),
MBCodec(name: "base32hexpad", code: 't',
decr: b32hpd, encr: b32hpe, decl: b32pdl, encl: b32pel
),
MBCodec(name: "base32hexpadupper", code: 'T',
decr: b32hpud, encr: b32hpue, decl: b32pdl, encl: b32pel
),
MBCodec(name: "base32", code: 'b',
decr: b32d, encr: b32e, decl: b32dl, encl: b32el
),
MBCodec(name: "base32upper", code: 'B',
decr: b32ud, encr: b32ue, decl: b32dl, encl: b32el
),
MBCodec(name: "base32pad", code: 'c',
decr: b32pd, encr: b32pe, decl: b32pdl, encl: b32pel
),
MBCodec(name: "base32padupper", code: 'C',
decr: b32pud, encr: b32pue, decl: b32pdl, encl: b32pel
),
MBCodec(name: "base32z", code: 'h'),
MBCodec(name: "base58flickr", code: 'Z',
decr: b58fd, encr: b58fe, decl: b58dl, encl: b58el
),
MBCodec(name: "base58btc", code: 'z',
decr: b58bd, encr: b58be, decl: b58dl, encl: b58el
),
MBCodec(name: "base64", code: 'm',
decr: b64d, encr: b64e, decl: b64dl, encl: b64el
),
MBCodec(name: "base64pad", code: 'M',
decr: b64pd, encr: b64pe, decl: b64pdl, encl: b64pel
),
MBCodec(name: "base64url", code: 'u',
decr: b64ud, encr: b64ue, decl: b64dl, encl: b64el
),
MBCodec(name: "base64urlpad", code: 'U',
decr: b64upd, encr: b64upe, decl: b64pdl, encl: b64pel
)
]
proc initMultiBaseCodeTable(): Table[char, MBCodec] {.compileTime.} =
for item in MultiBaseCodecs:
@@ -386,7 +342,8 @@ const
CodeMultiBases = initMultiBaseCodeTable()
NameMultiBases = initMultiBaseNameTable()
proc encodedLength*(mbtype: typedesc[MultiBase], encoding: string, length: int): int =
proc encodedLength*(mbtype: typedesc[MultiBase], encoding: string,
length: int): int =
## Return estimated size of buffer to store MultiBase encoded value with
## encoding ``encoding`` of length ``length``.
##
@@ -401,7 +358,8 @@ proc encodedLength*(mbtype: typedesc[MultiBase], encoding: string, length: int):
else:
result = mb.encl(length) + 1
proc decodedLength*(mbtype: typedesc[MultiBase], encoding: char, length: int): int =
proc decodedLength*(mbtype: typedesc[MultiBase], encoding: char,
length: int): int =
## Return estimated size of buffer to store MultiBase decoded value with
## encoding character ``encoding`` of length ``length``.
let mb = CodeMultiBases.getOrDefault(encoding)
@@ -413,13 +371,9 @@ proc decodedLength*(mbtype: typedesc[MultiBase], encoding: char, length: int): i
else:
result = mb.decl(length - 1)
proc encode*(
mbtype: typedesc[MultiBase],
encoding: string,
inbytes: openArray[byte],
outbytes: var openArray[char],
outlen: var int,
): MultiBaseStatus =
proc encode*(mbtype: typedesc[MultiBase], encoding: string,
inbytes: openArray[byte], outbytes: var openArray[char],
outlen: var int): MultiBaseStatus =
## Encode array ``inbytes`` using MultiBase encoding scheme ``encoding`` and
## store encoded value to ``outbytes``.
##
@@ -441,7 +395,8 @@ proc encode*(
if isNil(mb.encr) or isNil(mb.encl):
return MultiBaseStatus.NotSupported
if len(outbytes) > 1:
result = mb.encr(inbytes, outbytes.toOpenArray(1, outbytes.high), outlen)
result = mb.encr(inbytes, outbytes.toOpenArray(1, outbytes.high),
outlen)
if result == MultiBaseStatus.Overrun:
outlen += 1
elif result == MultiBaseStatus.Success:
@@ -456,12 +411,8 @@ proc encode*(
result = MultiBaseStatus.Overrun
outlen = mb.encl(len(inbytes)) + 1
proc decode*(
mbtype: typedesc[MultiBase],
inbytes: openArray[char],
outbytes: var openArray[byte],
outlen: var int,
): MultiBaseStatus =
proc decode*(mbtype: typedesc[MultiBase], inbytes: openArray[char],
outbytes: var openArray[byte], outlen: var int): MultiBaseStatus =
## Decode array ``inbytes`` using MultiBase encoding and store decoded value
## to ``outbytes``.
##
@@ -490,9 +441,8 @@ proc decode*(
else:
result = mb.decr(inbytes.toOpenArray(1, length - 1), outbytes, outlen)
proc encode*(
mbtype: typedesc[MultiBase], encoding: string, inbytes: openArray[byte]
): Result[string, string] =
proc encode*(mbtype: typedesc[MultiBase], encoding: string,
inbytes: openArray[byte]): Result[string, string] =
## Encode array ``inbytes`` using MultiBase encoding scheme ``encoding`` and
## return encoded string.
let length = len(inbytes)
@@ -515,9 +465,7 @@ proc encode*(
buffer[0] = mb.code
ok(buffer)
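Paired with the decode overload just below, the string-returning encode gives a simple round-trip; a sketch:

let encoded = MultiBase.encode("base58btc", [byte 1, 2, 3]).tryGet()
doAssert MultiBase.decode(encoded).tryGet() == @[byte 1, 2, 3]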
proc decode*(
mbtype: typedesc[MultiBase], inbytes: openArray[char]
): Result[seq[byte], string] =
proc decode*(mbtype: typedesc[MultiBase], inbytes: openArray[char]): Result[seq[byte], string] =
## Decode MultiBase encoded array ``inbytes`` and return decoded sequence of
## bytes.
let length = len(inbytes)
@@ -534,7 +482,8 @@ proc decode*(
else:
var buffer = newSeq[byte](mb.decl(length - 1))
var outlen = 0
let res = mb.decr(inbytes.toOpenArray(1, length - 1), buffer, outlen)
let res = mb.decr(inbytes.toOpenArray(1, length - 1),
buffer, outlen)
if res != MultiBaseStatus.Success:
err("multibase: Decoding error [" & $res & "]")
else:
View File
@@ -9,13 +9,18 @@
## This module implements MultiCodec.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import tables, hashes
import vbuffer
import varint, vbuffer
import stew/results
export results
{.deadCodeElim: on.}
## The list of officially supported codecs can be found here
## https://github.com/multiformats/multicodec/blob/master/table.csv
const MultiCodecList = [
@@ -49,325 +54,132 @@ const MultiCodecList = [
("keccak-384", 0x1C),
("keccak-512", 0x1D),
("murmur3", 0x22),
("blake2b-8", 0xB201),
("blake2b-16", 0xB202),
("blake2b-24", 0xB203),
("blake2b-32", 0xB204),
("blake2b-40", 0xB205),
("blake2b-48", 0xB206),
("blake2b-56", 0xB207),
("blake2b-64", 0xB208),
("blake2b-72", 0xB209),
("blake2b-80", 0xB20A),
("blake2b-88", 0xB20B),
("blake2b-96", 0xB20C),
("blake2b-104", 0xB20D),
("blake2b-112", 0xB20E),
("blake2b-120", 0xB20F),
("blake2b-128", 0xB210),
("blake2b-136", 0xB211),
("blake2b-144", 0xB212),
("blake2b-152", 0xB213),
("blake2b-160", 0xB214),
("blake2b-168", 0xB215),
("blake2b-176", 0xB216),
("blake2b-184", 0xB217),
("blake2b-192", 0xB218),
("blake2b-200", 0xB219),
("blake2b-208", 0xB21A),
("blake2b-216", 0xB21B),
("blake2b-224", 0xB21C),
("blake2b-232", 0xB21D),
("blake2b-240", 0xB21E),
("blake2b-248", 0xB21F),
("blake2b-256", 0xB220),
("blake2b-264", 0xB221),
("blake2b-272", 0xB222),
("blake2b-280", 0xB223),
("blake2b-288", 0xB224),
("blake2b-296", 0xB225),
("blake2b-304", 0xB226),
("blake2b-312", 0xB227),
("blake2b-320", 0xB228),
("blake2b-328", 0xB229),
("blake2b-336", 0xB22A),
("blake2b-344", 0xB22B),
("blake2b-352", 0xB22C),
("blake2b-360", 0xB22D),
("blake2b-368", 0xB22E),
("blake2b-376", 0xB22F),
("blake2b-384", 0xB230),
("blake2b-392", 0xB231),
("blake2b-400", 0xB232),
("blake2b-408", 0xB233),
("blake2b-416", 0xB234),
("blake2b-424", 0xB235),
("blake2b-432", 0xB236),
("blake2b-440", 0xB237),
("blake2b-448", 0xB238),
("blake2b-456", 0xB239),
("blake2b-464", 0xB23A),
("blake2b-472", 0xB23B),
("blake2b-480", 0xB23C),
("blake2b-488", 0xB23D),
("blake2b-496", 0xB23E),
("blake2b-504", 0xB23F),
("blake2b-512", 0xB240),
("blake2s-8", 0xB241),
("blake2s-16", 0xB242),
("blake2s-24", 0xB243),
("blake2s-32", 0xB244),
("blake2s-40", 0xB245),
("blake2s-48", 0xB246),
("blake2s-56", 0xB247),
("blake2s-64", 0xB248),
("blake2s-72", 0xB249),
("blake2s-80", 0xB24A),
("blake2s-88", 0xB24B),
("blake2s-96", 0xB24C),
("blake2s-104", 0xB24D),
("blake2s-112", 0xB24E),
("blake2s-120", 0xB24F),
("blake2s-128", 0xB250),
("blake2s-136", 0xB251),
("blake2s-144", 0xB252),
("blake2s-152", 0xB253),
("blake2s-160", 0xB254),
("blake2s-168", 0xB255),
("blake2s-176", 0xB256),
("blake2s-184", 0xB257),
("blake2s-192", 0xB258),
("blake2s-200", 0xB259),
("blake2s-208", 0xB25A),
("blake2s-216", 0xB25B),
("blake2s-224", 0xB25C),
("blake2s-232", 0xB25D),
("blake2s-240", 0xB25E),
("blake2s-248", 0xB25F),
("blake2s-256", 0xB260),
("skein256-8", 0xB301),
("skein256-16", 0xB302),
("skein256-24", 0xB303),
("skein256-32", 0xB304),
("skein256-40", 0xB305),
("skein256-48", 0xB306),
("skein256-56", 0xB307),
("skein256-64", 0xB308),
("skein256-72", 0xB309),
("skein256-80", 0xB30A),
("skein256-88", 0xB30B),
("skein256-96", 0xB30C),
("skein256-104", 0xB30D),
("skein256-112", 0xB30E),
("skein256-120", 0xB30F),
("skein256-128", 0xB310),
("skein256-136", 0xB311),
("skein256-144", 0xB312),
("skein256-152", 0xB313),
("skein256-160", 0xB314),
("skein256-168", 0xB315),
("skein256-176", 0xB316),
("skein256-184", 0xB317),
("skein256-192", 0xB318),
("skein256-200", 0xB319),
("skein256-208", 0xB31A),
("skein256-216", 0xB31B),
("skein256-224", 0xB31C),
("skein256-232", 0xB31D),
("skein256-240", 0xB31E),
("skein256-248", 0xB31F),
("skein256-256", 0xB320),
("skein512-8", 0xB321),
("skein512-16", 0xB322),
("skein512-24", 0xB323),
("skein512-32", 0xB324),
("skein512-40", 0xB325),
("skein512-48", 0xB326),
("skein512-56", 0xB327),
("skein512-64", 0xB328),
("skein512-72", 0xB329),
("skein512-80", 0xB32A),
("skein512-88", 0xB32B),
("skein512-96", 0xB32C),
("skein512-104", 0xB32D),
("skein512-112", 0xB32E),
("skein512-120", 0xB32F),
("skein512-128", 0xB330),
("skein512-136", 0xB331),
("skein512-144", 0xB332),
("skein512-152", 0xB333),
("skein512-160", 0xB334),
("skein512-168", 0xB335),
("skein512-176", 0xB336),
("skein512-184", 0xB337),
("skein512-192", 0xB338),
("skein512-200", 0xB339),
("skein512-208", 0xB33A),
("skein512-216", 0xB33B),
("skein512-224", 0xB33C),
("skein512-232", 0xB33D),
("skein512-240", 0xB33E),
("skein512-248", 0xB33F),
("skein512-256", 0xB340),
("skein512-264", 0xB341),
("skein512-272", 0xB342),
("skein512-280", 0xB343),
("skein512-288", 0xB344),
("skein512-296", 0xB345),
("skein512-304", 0xB346),
("skein512-312", 0xB347),
("skein512-320", 0xB348),
("skein512-328", 0xB349),
("skein512-336", 0xB34A),
("skein512-344", 0xB34B),
("skein512-352", 0xB34C),
("skein512-360", 0xB34D),
("skein512-368", 0xB34E),
("skein512-376", 0xB34F),
("skein512-384", 0xB350),
("skein512-392", 0xB351),
("skein512-400", 0xB352),
("skein512-408", 0xB353),
("skein512-416", 0xB354),
("skein512-424", 0xB355),
("skein512-432", 0xB356),
("skein512-440", 0xB357),
("skein512-448", 0xB358),
("skein512-456", 0xB359),
("skein512-464", 0xB35A),
("skein512-472", 0xB35B),
("skein512-480", 0xB35C),
("skein512-488", 0xB35D),
("skein512-496", 0xB35E),
("skein512-504", 0xB35F),
("skein512-512", 0xB360),
("skein1024-8", 0xB361),
("skein1024-16", 0xB362),
("skein1024-24", 0xB363),
("skein1024-32", 0xB364),
("skein1024-40", 0xB365),
("skein1024-48", 0xB366),
("skein1024-56", 0xB367),
("skein1024-64", 0xB368),
("skein1024-72", 0xB369),
("skein1024-80", 0xB36A),
("skein1024-88", 0xB36B),
("skein1024-96", 0xB36C),
("skein1024-104", 0xB36D),
("skein1024-112", 0xB36E),
("skein1024-120", 0xB36F),
("skein1024-128", 0xB370),
("skein1024-136", 0xB371),
("skein1024-144", 0xB372),
("skein1024-152", 0xB373),
("skein1024-160", 0xB374),
("skein1024-168", 0xB375),
("skein1024-176", 0xB376),
("skein1024-184", 0xB377),
("skein1024-192", 0xB378),
("skein1024-200", 0xB379),
("skein1024-208", 0xB37A),
("skein1024-216", 0xB37B),
("skein1024-224", 0xB37C),
("skein1024-232", 0xB37D),
("skein1024-240", 0xB37E),
("skein1024-248", 0xB37F),
("skein1024-256", 0xB380),
("skein1024-264", 0xB381),
("skein1024-272", 0xB382),
("skein1024-280", 0xB383),
("skein1024-288", 0xB384),
("skein1024-296", 0xB385),
("skein1024-304", 0xB386),
("skein1024-312", 0xB387),
("skein1024-320", 0xB388),
("skein1024-328", 0xB389),
("skein1024-336", 0xB38A),
("skein1024-344", 0xB38B),
("skein1024-352", 0xB38C),
("skein1024-360", 0xB38D),
("skein1024-368", 0xB38E),
("skein1024-376", 0xB38F),
("skein1024-384", 0xB390),
("skein1024-392", 0xB391),
("skein1024-400", 0xB392),
("skein1024-408", 0xB393),
("skein1024-416", 0xB394),
("skein1024-424", 0xB395),
("skein1024-432", 0xB396),
("skein1024-440", 0xB397),
("skein1024-448", 0xB398),
("skein1024-456", 0xB399),
("skein1024-464", 0xB39A),
("skein1024-472", 0xB39B),
("skein1024-480", 0xB39C),
("skein1024-488", 0xB39D),
("skein1024-496", 0xB39E),
("skein1024-504", 0xB39F),
("skein1024-512", 0xB3A0),
("skein1024-520", 0xB3A1),
("skein1024-528", 0xB3A2),
("skein1024-536", 0xB3A3),
("skein1024-544", 0xB3A4),
("skein1024-552", 0xB3A5),
("skein1024-560", 0xB3A6),
("skein1024-568", 0xB3A7),
("skein1024-576", 0xB3A8),
("skein1024-584", 0xB3A9),
("skein1024-592", 0xB3AA),
("skein1024-600", 0xB3AB),
("skein1024-608", 0xB3AC),
("skein1024-616", 0xB3AD),
("skein1024-624", 0xB3AE),
("skein1024-632", 0xB3AF),
("skein1024-640", 0xB3B0),
("skein1024-648", 0xB3B1),
("skein1024-656", 0xB3B2),
("skein1024-664", 0xB3B3),
("skein1024-672", 0xB3B4),
("skein1024-680", 0xB3B5),
("skein1024-688", 0xB3B6),
("skein1024-696", 0xB3B7),
("skein1024-704", 0xB3B8),
("skein1024-712", 0xB3B9),
("skein1024-720", 0xB3BA),
("skein1024-728", 0xB3BB),
("skein1024-736", 0xB3BC),
("skein1024-744", 0xB3BD),
("skein1024-752", 0xB3BE),
("skein1024-760", 0xB3BF),
("skein1024-768", 0xB3C0),
("skein1024-776", 0xB3C1),
("skein1024-784", 0xB3C2),
("skein1024-792", 0xB3C3),
("skein1024-800", 0xB3C4),
("skein1024-808", 0xB3C5),
("skein1024-816", 0xB3C6),
("skein1024-824", 0xB3C7),
("skein1024-832", 0xB3C8),
("skein1024-840", 0xB3C9),
("skein1024-848", 0xB3CA),
("skein1024-856", 0xB3CB),
("skein1024-864", 0xB3CC),
("skein1024-872", 0xB3CD),
("skein1024-880", 0xB3CE),
("skein1024-888", 0xB3CF),
("skein1024-896", 0xB3D0),
("skein1024-904", 0xB3D1),
("skein1024-912", 0xB3D2),
("skein1024-920", 0xB3D3),
("skein1024-928", 0xB3D4),
("skein1024-936", 0xB3D5),
("skein1024-944", 0xB3D6),
("skein1024-952", 0xB3D7),
("skein1024-960", 0xB3D8),
("skein1024-968", 0xB3D9),
("skein1024-976", 0xB3DA),
("skein1024-984", 0xB3DB),
("skein1024-992", 0xB3DC),
("skein1024-1000", 0xB3DD),
("skein1024-1008", 0xB3DE),
("skein1024-1016", 0xB3DF),
("blake2b-8", 0xB201), ("blake2b-16", 0xB202), ("blake2b-24", 0xB203),
("blake2b-32", 0xB204), ("blake2b-40", 0xB205), ("blake2b-48", 0xB206),
("blake2b-56", 0xB207), ("blake2b-64", 0xB208), ("blake2b-72", 0xB209),
("blake2b-80", 0xB20A), ("blake2b-88", 0xB20B), ("blake2b-96", 0xB20C),
("blake2b-104", 0xB20D), ("blake2b-112", 0xB20E), ("blake2b-120", 0xB20F),
("blake2b-128", 0xB210), ("blake2b-136", 0xB211), ("blake2b-144", 0xB212),
("blake2b-152", 0xB213), ("blake2b-160", 0xB214), ("blake2b-168", 0xB215),
("blake2b-176", 0xB216), ("blake2b-184", 0xB217), ("blake2b-192", 0xB218),
("blake2b-200", 0xB219), ("blake2b-208", 0xB21A), ("blake2b-216", 0xB21B),
("blake2b-224", 0xB21C), ("blake2b-232", 0xB21D), ("blake2b-240", 0xB21E),
("blake2b-248", 0xB21F), ("blake2b-256", 0xB220), ("blake2b-264", 0xB221),
("blake2b-272", 0xB222), ("blake2b-280", 0xB223), ("blake2b-288", 0xB224),
("blake2b-296", 0xB225), ("blake2b-304", 0xB226), ("blake2b-312", 0xB227),
("blake2b-320", 0xB228), ("blake2b-328", 0xB229), ("blake2b-336", 0xB22A),
("blake2b-344", 0xB22B), ("blake2b-352", 0xB22C), ("blake2b-360", 0xB22D),
("blake2b-368", 0xB22E), ("blake2b-376", 0xB22F), ("blake2b-384", 0xB230),
("blake2b-392", 0xB231), ("blake2b-400", 0xB232), ("blake2b-408", 0xB233),
("blake2b-416", 0xB234), ("blake2b-424", 0xB235), ("blake2b-432", 0xB236),
("blake2b-440", 0xB237), ("blake2b-448", 0xB238), ("blake2b-456", 0xB239),
("blake2b-464", 0xB23A), ("blake2b-472", 0xB23B), ("blake2b-480", 0xB23C),
("blake2b-488", 0xB23D), ("blake2b-496", 0xB23E), ("blake2b-504", 0xB23F),
("blake2b-512", 0xB240), ("blake2s-8", 0xB241), ("blake2s-16", 0xB242),
("blake2s-24", 0xB243), ("blake2s-32", 0xB244), ("blake2s-40", 0xB245),
("blake2s-48", 0xB246), ("blake2s-56", 0xB247), ("blake2s-64", 0xB248),
("blake2s-72", 0xB249), ("blake2s-80", 0xB24A), ("blake2s-88", 0xB24B),
("blake2s-96", 0xB24C), ("blake2s-104", 0xB24D), ("blake2s-112", 0xB24E),
("blake2s-120", 0xB24F), ("blake2s-128", 0xB250), ("blake2s-136", 0xB251),
("blake2s-144", 0xB252), ("blake2s-152", 0xB253), ("blake2s-160", 0xB254),
("blake2s-168", 0xB255), ("blake2s-176", 0xB256), ("blake2s-184", 0xB257),
("blake2s-192", 0xB258), ("blake2s-200", 0xB259), ("blake2s-208", 0xB25A),
("blake2s-216", 0xB25B), ("blake2s-224", 0xB25C), ("blake2s-232", 0xB25D),
("blake2s-240", 0xB25E), ("blake2s-248", 0xB25F), ("blake2s-256", 0xB260),
("skein256-8", 0xB301), ("skein256-16", 0xB302), ("skein256-24", 0xB303),
("skein256-32", 0xB304), ("skein256-40", 0xB305), ("skein256-48", 0xB306),
("skein256-56", 0xB307), ("skein256-64", 0xB308), ("skein256-72", 0xB309),
("skein256-80", 0xB30A), ("skein256-88", 0xB30B), ("skein256-96", 0xB30C),
("skein256-104", 0xB30D), ("skein256-112", 0xB30E), ("skein256-120", 0xB30F),
("skein256-128", 0xB310), ("skein256-136", 0xB311), ("skein256-144", 0xB312),
("skein256-152", 0xB313), ("skein256-160", 0xB314), ("skein256-168", 0xB315),
("skein256-176", 0xB316), ("skein256-184", 0xB317), ("skein256-192", 0xB318),
("skein256-200", 0xB319), ("skein256-208", 0xB31A), ("skein256-216", 0xB31B),
("skein256-224", 0xB31C), ("skein256-232", 0xB31D), ("skein256-240", 0xB31E),
("skein256-248", 0xB31F), ("skein256-256", 0xB320),
("skein512-8", 0xB321), ("skein512-16", 0xB322), ("skein512-24", 0xB323),
("skein512-32", 0xB324), ("skein512-40", 0xB325), ("skein512-48", 0xB326),
("skein512-56", 0xB327), ("skein512-64", 0xB328), ("skein512-72", 0xB329),
("skein512-80", 0xB32A), ("skein512-88", 0xB32B), ("skein512-96", 0xB32C),
("skein512-104", 0xB32D), ("skein512-112", 0xB32E), ("skein512-120", 0xB32F),
("skein512-128", 0xB330), ("skein512-136", 0xB331), ("skein512-144", 0xB332),
("skein512-152", 0xB333), ("skein512-160", 0xB334), ("skein512-168", 0xB335),
("skein512-176", 0xB336), ("skein512-184", 0xB337), ("skein512-192", 0xB338),
("skein512-200", 0xB339), ("skein512-208", 0xB33A), ("skein512-216", 0xB33B),
("skein512-224", 0xB33C), ("skein512-232", 0xB33D), ("skein512-240", 0xB33E),
("skein512-248", 0xB33F), ("skein512-256", 0xB340), ("skein512-264", 0xB341),
("skein512-272", 0xB342), ("skein512-280", 0xB343), ("skein512-288", 0xB344),
("skein512-296", 0xB345), ("skein512-304", 0xB346), ("skein512-312", 0xB347),
("skein512-320", 0xB348), ("skein512-328", 0xB349), ("skein512-336", 0xB34A),
("skein512-344", 0xB34B), ("skein512-352", 0xB34C), ("skein512-360", 0xB34D),
("skein512-368", 0xB34E), ("skein512-376", 0xB34F), ("skein512-384", 0xB350),
("skein512-392", 0xB351), ("skein512-400", 0xB352), ("skein512-408", 0xB353),
("skein512-416", 0xB354), ("skein512-424", 0xB355), ("skein512-432", 0xB356),
("skein512-440", 0xB357), ("skein512-448", 0xB358), ("skein512-456", 0xB359),
("skein512-464", 0xB35A), ("skein512-472", 0xB35B), ("skein512-480", 0xB35C),
("skein512-488", 0xB35D), ("skein512-496", 0xB35E), ("skein512-504", 0xB35F),
("skein512-512", 0xB360), ("skein1024-8", 0xB361), ("skein1024-16", 0xB362),
("skein1024-24", 0xB363), ("skein1024-32", 0xB364), ("skein1024-40", 0xB365),
("skein1024-48", 0xB366), ("skein1024-56", 0xB367), ("skein1024-64", 0xB368),
("skein1024-72", 0xB369), ("skein1024-80", 0xB36A), ("skein1024-88", 0xB36B),
("skein1024-96", 0xB36C), ("skein1024-104", 0xB36D),
("skein1024-112", 0xB36E), ("skein1024-120", 0xB36F),
("skein1024-128", 0xB370), ("skein1024-136", 0xB371),
("skein1024-144", 0xB372), ("skein1024-152", 0xB373),
("skein1024-160", 0xB374), ("skein1024-168", 0xB375),
("skein1024-176", 0xB376), ("skein1024-184", 0xB377),
("skein1024-192", 0xB378), ("skein1024-200", 0xB379),
("skein1024-208", 0xB37A), ("skein1024-216", 0xB37B),
("skein1024-224", 0xB37C), ("skein1024-232", 0xB37D),
("skein1024-240", 0xB37E), ("skein1024-248", 0xB37F),
("skein1024-256", 0xB380), ("skein1024-264", 0xB381),
("skein1024-272", 0xB382), ("skein1024-280", 0xB383),
("skein1024-288", 0xB384), ("skein1024-296", 0xB385),
("skein1024-304", 0xB386), ("skein1024-312", 0xB387),
("skein1024-320", 0xB388), ("skein1024-328", 0xB389),
("skein1024-336", 0xB38A), ("skein1024-344", 0xB38B),
("skein1024-352", 0xB38C), ("skein1024-360", 0xB38D),
("skein1024-368", 0xB38E), ("skein1024-376", 0xB38F),
("skein1024-384", 0xB390), ("skein1024-392", 0xB391),
("skein1024-400", 0xB392), ("skein1024-408", 0xB393),
("skein1024-416", 0xB394), ("skein1024-424", 0xB395),
("skein1024-432", 0xB396), ("skein1024-440", 0xB397),
("skein1024-448", 0xB398), ("skein1024-456", 0xB399),
("skein1024-464", 0xB39A), ("skein1024-472", 0xB39B),
("skein1024-480", 0xB39C), ("skein1024-488", 0xB39D),
("skein1024-496", 0xB39E), ("skein1024-504", 0xB39F),
("skein1024-512", 0xB3A0), ("skein1024-520", 0xB3A1),
("skein1024-528", 0xB3A2), ("skein1024-536", 0xB3A3),
("skein1024-544", 0xB3A4), ("skein1024-552", 0xB3A5),
("skein1024-560", 0xB3A6), ("skein1024-568", 0xB3A7),
("skein1024-576", 0xB3A8), ("skein1024-584", 0xB3A9),
("skein1024-592", 0xB3AA), ("skein1024-600", 0xB3AB),
("skein1024-608", 0xB3AC), ("skein1024-616", 0xB3AD),
("skein1024-624", 0xB3AE), ("skein1024-632", 0xB3AF),
("skein1024-640", 0xB3B0), ("skein1024-648", 0xB3B1),
("skein1024-656", 0xB3B2), ("skein1024-664", 0xB3B3),
("skein1024-672", 0xB3B4), ("skein1024-680", 0xB3B5),
("skein1024-688", 0xB3B6), ("skein1024-696", 0xB3B7),
("skein1024-704", 0xB3B8), ("skein1024-712", 0xB3B9),
("skein1024-720", 0xB3BA), ("skein1024-728", 0xB3BB),
("skein1024-736", 0xB3BC), ("skein1024-744", 0xB3BD),
("skein1024-752", 0xB3BE), ("skein1024-760", 0xB3BF),
("skein1024-768", 0xB3C0), ("skein1024-776", 0xB3C1),
("skein1024-784", 0xB3C2), ("skein1024-792", 0xB3C3),
("skein1024-800", 0xB3C4), ("skein1024-808", 0xB3C5),
("skein1024-816", 0xB3C6), ("skein1024-824", 0xB3C7),
("skein1024-832", 0xB3C8), ("skein1024-840", 0xB3C9),
("skein1024-848", 0xB3CA), ("skein1024-856", 0xB3CB),
("skein1024-864", 0xB3CC), ("skein1024-872", 0xB3CD),
("skein1024-880", 0xB3CE), ("skein1024-888", 0xB3CF),
("skein1024-896", 0xB3D0), ("skein1024-904", 0xB3D1),
("skein1024-912", 0xB3D2), ("skein1024-920", 0xB3D3),
("skein1024-928", 0xB3D4), ("skein1024-936", 0xB3D5),
("skein1024-944", 0xB3D6), ("skein1024-952", 0xB3D7),
("skein1024-960", 0xB3D8), ("skein1024-968", 0xB3D9),
("skein1024-976", 0xB3DA), ("skein1024-984", 0xB3DB),
("skein1024-992", 0xB3DC), ("skein1024-1000", 0xB3DD),
("skein1024-1008", 0xB3DE), ("skein1024-1016", 0xB3DF),
("skein1024-1024", 0xB3E0),
# multiaddrs
("ip4", 0x04),
@@ -384,11 +196,9 @@ const MultiCodecList = [
("p2p", 0x01A5),
("http", 0x01E0),
("https", 0x01BB),
("tls", 0x01C0),
("quic", 0x01CC),
("quic-v1", 0x01CD),
("ws", 0x01DD),
("wss", 0x01DE),
("wss", 0x01DE), # not in multicodec list
("p2p-websocket-star", 0x01DF), # not in multicodec list
("p2p-webrtc-star", 0x0113), # not in multicodec list
("p2p-webrtc-direct", 0x0114), # not in multicodec list
@@ -426,7 +236,7 @@ const MultiCodecList = [
("dash-tx", 0xF1),
("torrent-info", 0x7B),
("torrent-file", 0x7C),
("ed25519-pub", 0xED),
("ed25519-pub", 0xED)
]
type
@@ -434,7 +244,8 @@ type
MultiCodecError* = enum
MultiCodecNotSupported
const InvalidMultiCodec* = MultiCodec(-1)
const
InvalidMultiCodec* = MultiCodec(-1)
proc initMultiCodecNameTable(): Table[string, int] {.compileTime.} =
for item in MultiCodecList:
@@ -481,6 +292,10 @@ proc `==`*(a, b: MultiCodec): bool =
## Returns ``true`` if MultiCodecs ``a`` and ``b`` are equal.
int(a) == int(b)
proc `!=`*(a, b: MultiCodec): bool =
## Returns ``true`` if MultiCodecs ``a`` and ``b`` are not equal.
int(a) != int(b)
proc hash*(m: MultiCodec): Hash {.inline.} =
## Hash procedure for tables.
hash(int(m))
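
In practice the table above is consumed through two lookups: the compile-time `multiCodec(name)` (used later by `HashesList`), which rejects unknown names at build time, and the runtime `MultiCodec.codec(name)`, which returns `InvalidMultiCodec` for unknown names instead of raising. A minimal usage sketch — the `libp2p/multicodec` import path is an assumption and may differ per checkout:

import libp2p/multicodec

# Compile-time lookup: an unknown name is a compile error.
const shaCodec = multiCodec("sha2-256")

# Runtime lookup: unknown names yield InvalidMultiCodec.
let rt = MultiCodec.codec("sha2-256")
assert rt != InvalidMultiCodec
assert rt == shaCodec

# `$` maps the numeric code back to its canonical table name.
echo shaCodec                                           # sha2-256
assert MultiCodec.codec("no-such-codec") == InvalidMultiCodec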

View File

@@ -21,7 +21,10 @@
## 1. SKEIN
## 2. MURMUR
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import tables
import nimcrypto/[sha, sha2, keccak, blake2, hash, utils]
@@ -41,9 +44,8 @@ const
ErrParseError = "Parse error fromHex"
type
MHashCoderProc* = proc(data: openArray[byte], output: var openArray[byte]) {.
nimcall, gcsafe, noSideEffect, raises: []
.}
MHashCoderProc* = proc(data: openArray[byte],
output: var openArray[byte]) {.nimcall, gcsafe, noSideEffect, raises: [Defect].}
MHash* = object
mcodec*: MultiCodec
size*: int
@@ -59,152 +61,107 @@ type
proc identhash(data: openArray[byte], output: var openArray[byte]) =
if len(output) > 0:
var length =
if len(data) > len(output):
len(output)
else:
len(data)
var length = if len(data) > len(output): len(output)
else: len(data)
copyMem(addr output[0], unsafeAddr data[0], length)
proc sha1hash(data: openArray[byte], output: var openArray[byte]) =
if len(output) > 0:
var digest = sha1.digest(data)
var length =
if sha1.sizeDigest > len(output):
len(output)
else:
sha1.sizeDigest
var length = if sha1.sizeDigest > len(output): len(output)
else: sha1.sizeDigest
copyMem(addr output[0], addr digest.data[0], length)
proc dblsha2_256hash(data: openArray[byte], output: var openArray[byte]) =
if len(output) > 0:
var digest1 = sha256.digest(data)
var digest2 = sha256.digest(digest1.data)
var length =
if sha256.sizeDigest > len(output):
len(output)
else:
sha256.sizeDigest
var length = if sha256.sizeDigest > len(output): len(output)
else: sha256.sizeDigest
copyMem(addr output[0], addr digest2.data[0], length)
proc blake2Bhash(data: openArray[byte], output: var openArray[byte]) =
if len(output) > 0:
var digest = blake2_512.digest(data)
var length =
if blake2_512.sizeDigest > len(output):
len(output)
else:
blake2_512.sizeDigest
var length = if blake2_512.sizeDigest > len(output): len(output)
else: blake2_512.sizeDigest
copyMem(addr output[0], addr digest.data[0], length)
proc blake2Shash(data: openArray[byte], output: var openArray[byte]) =
if len(output) > 0:
var digest = blake2_256.digest(data)
var length =
if blake2_256.sizeDigest > len(output):
len(output)
else:
blake2_256.sizeDigest
var length = if blake2_256.sizeDigest > len(output): len(output)
else: blake2_256.sizeDigest
copyMem(addr output[0], addr digest.data[0], length)
proc sha2_256hash(data: openArray[byte], output: var openArray[byte]) =
if len(output) > 0:
var digest = sha256.digest(data)
var length =
if sha256.sizeDigest > len(output):
len(output)
else:
sha256.sizeDigest
var length = if sha256.sizeDigest > len(output): len(output)
else: sha256.sizeDigest
copyMem(addr output[0], addr digest.data[0], length)
proc sha2_512hash(data: openArray[byte], output: var openArray[byte]) =
if len(output) > 0:
var digest = sha512.digest(data)
var length =
if sha512.sizeDigest > len(output):
len(output)
else:
sha512.sizeDigest
var length = if sha512.sizeDigest > len(output): len(output)
else: sha512.sizeDigest
copyMem(addr output[0], addr digest.data[0], length)
proc sha3_224hash(data: openArray[byte], output: var openArray[byte]) =
if len(output) > 0:
var digest = sha3_224.digest(data)
var length =
if sha3_224.sizeDigest > len(output):
len(output)
else:
sha3_224.sizeDigest
var length = if sha3_224.sizeDigest > len(output): len(output)
else: sha3_224.sizeDigest
copyMem(addr output[0], addr digest.data[0], length)
proc sha3_256hash(data: openArray[byte], output: var openArray[byte]) =
if len(output) > 0:
var digest = sha3_256.digest(data)
var length =
if sha3_256.sizeDigest > len(output):
len(output)
else:
sha3_256.sizeDigest
var length = if sha3_256.sizeDigest > len(output): len(output)
else: sha3_256.sizeDigest
copyMem(addr output[0], addr digest.data[0], length)
proc sha3_384hash(data: openArray[byte], output: var openArray[byte]) =
if len(output) > 0:
var digest = sha3_384.digest(data)
var length =
if sha3_384.sizeDigest > len(output):
len(output)
else:
sha3_384.sizeDigest
var length = if sha3_384.sizeDigest > len(output): len(output)
else: sha3_384.sizeDigest
copyMem(addr output[0], addr digest.data[0], length)
proc sha3_512hash(data: openArray[byte], output: var openArray[byte]) =
if len(output) > 0:
var digest = sha3_512.digest(data)
var length =
if sha3_512.sizeDigest > len(output):
len(output)
else:
sha3_512.sizeDigest
var length = if sha3_512.sizeDigest > len(output): len(output)
else: sha3_512.sizeDigest
copyMem(addr output[0], addr digest.data[0], length)
proc keccak_224hash(data: openArray[byte], output: var openArray[byte]) =
if len(output) > 0:
var digest = keccak224.digest(data)
var length =
if keccak224.sizeDigest > len(output):
len(output)
else:
keccak224.sizeDigest
var length = if keccak224.sizeDigest > len(output): len(output)
else: keccak224.sizeDigest
copyMem(addr output[0], addr digest.data[0], length)
proc keccak_256hash(data: openArray[byte], output: var openArray[byte]) =
if len(output) > 0:
var digest = keccak256.digest(data)
var length =
if keccak256.sizeDigest > len(output):
len(output)
else:
keccak256.sizeDigest
var length = if keccak256.sizeDigest > len(output): len(output)
else: keccak256.sizeDigest
copyMem(addr output[0], addr digest.data[0], length)
proc keccak_384hash(data: openArray[byte], output: var openArray[byte]) =
if len(output) > 0:
var digest = keccak384.digest(data)
var length =
if keccak384.sizeDigest > len(output):
len(output)
else:
keccak384.sizeDigest
var length = if keccak384.sizeDigest > len(output): len(output)
else: keccak384.sizeDigest
copyMem(addr output[0], addr digest.data[0], length)
proc keccak_512hash(data: openArray[byte], output: var openArray[byte]) =
if len(output) > 0:
var digest = keccak512.digest(data)
var length =
if keccak512.sizeDigest > len(output):
len(output)
else:
keccak512.sizeDigest
var length = if keccak512.sizeDigest > len(output): len(output)
else: keccak512.sizeDigest
copyMem(addr output[0], addr digest.data[0], length)
proc shake_128hash(data: openArray[byte], output: var openArray[byte]) =
@@ -225,135 +182,151 @@ proc shake_256hash(data: openArray[byte], output: var openArray[byte]) =
discard sctx.output(addr output[0], uint(len(output)))
sctx.clear()
const HashesList = [
MHash(mcodec: multiCodec("identity"), size: 0, coder: identhash),
MHash(mcodec: multiCodec("sha1"), size: sha1.sizeDigest, coder: sha1hash),
MHash(
mcodec: multiCodec("dbl-sha2-256"), size: sha256.sizeDigest, coder: dblsha2_256hash
),
MHash(mcodec: multiCodec("sha2-256"), size: sha256.sizeDigest, coder: sha2_256hash),
MHash(mcodec: multiCodec("sha2-512"), size: sha512.sizeDigest, coder: sha2_512hash),
MHash(mcodec: multiCodec("sha3-224"), size: sha3_224.sizeDigest, coder: sha3_224hash),
MHash(mcodec: multiCodec("sha3-256"), size: sha3_256.sizeDigest, coder: sha3_256hash),
MHash(mcodec: multiCodec("sha3-384"), size: sha3_384.sizeDigest, coder: sha3_384hash),
MHash(mcodec: multiCodec("sha3-512"), size: sha3_512.sizeDigest, coder: sha3_512hash),
MHash(mcodec: multiCodec("shake-128"), size: 32, coder: shake_128hash),
MHash(mcodec: multiCodec("shake-256"), size: 64, coder: shake_256hash),
MHash(
mcodec: multiCodec("keccak-224"), size: keccak224.sizeDigest, coder: keccak_224hash
),
MHash(
mcodec: multiCodec("keccak-256"), size: keccak256.sizeDigest, coder: keccak_256hash
),
MHash(
mcodec: multiCodec("keccak-384"), size: keccak384.sizeDigest, coder: keccak_384hash
),
MHash(
mcodec: multiCodec("keccak-512"), size: keccak512.sizeDigest, coder: keccak_512hash
),
MHash(mcodec: multiCodec("blake2b-8"), size: 1, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-16"), size: 2, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-24"), size: 3, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-32"), size: 4, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-40"), size: 5, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-48"), size: 6, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-56"), size: 7, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-64"), size: 8, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-72"), size: 9, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-80"), size: 10, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-88"), size: 11, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-96"), size: 12, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-104"), size: 13, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-112"), size: 14, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-120"), size: 15, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-128"), size: 16, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-136"), size: 17, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-144"), size: 18, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-152"), size: 19, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-160"), size: 20, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-168"), size: 21, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-176"), size: 22, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-184"), size: 23, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-192"), size: 24, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-200"), size: 25, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-208"), size: 26, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-216"), size: 27, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-224"), size: 28, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-232"), size: 29, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-240"), size: 30, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-248"), size: 31, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-256"), size: 32, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-264"), size: 33, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-272"), size: 34, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-280"), size: 35, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-288"), size: 36, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-296"), size: 37, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-304"), size: 38, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-312"), size: 39, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-320"), size: 40, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-328"), size: 41, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-336"), size: 42, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-344"), size: 43, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-352"), size: 44, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-360"), size: 45, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-368"), size: 46, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-376"), size: 47, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-384"), size: 48, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-392"), size: 49, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-400"), size: 50, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-408"), size: 51, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-416"), size: 52, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-424"), size: 53, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-432"), size: 54, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-440"), size: 55, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-448"), size: 56, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-456"), size: 57, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-464"), size: 58, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-472"), size: 59, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-480"), size: 60, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-488"), size: 61, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-496"), size: 62, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-504"), size: 63, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-512"), size: 64, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2s-8"), size: 1, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-16"), size: 2, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-24"), size: 3, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-32"), size: 4, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-40"), size: 5, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-48"), size: 6, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-56"), size: 7, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-64"), size: 8, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-72"), size: 9, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-80"), size: 10, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-88"), size: 11, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-96"), size: 12, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-104"), size: 13, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-112"), size: 14, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-120"), size: 15, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-128"), size: 16, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-136"), size: 17, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-144"), size: 18, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-152"), size: 19, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-160"), size: 20, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-168"), size: 21, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-176"), size: 22, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-184"), size: 23, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-192"), size: 24, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-200"), size: 25, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-208"), size: 26, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-216"), size: 27, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-224"), size: 28, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-232"), size: 29, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-240"), size: 30, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-248"), size: 31, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-256"), size: 32, coder: blake2Shash),
]
const
HashesList = [
MHash(mcodec: multiCodec("identity"), size: 0,
coder: identhash),
MHash(mcodec: multiCodec("sha1"), size: sha1.sizeDigest,
coder: sha1hash),
MHash(mcodec: multiCodec("dbl-sha2-256"), size: sha256.sizeDigest,
coder: dblsha2_256hash
),
MHash(mcodec: multiCodec("sha2-256"), size: sha256.sizeDigest,
coder: sha2_256hash
),
MHash(mcodec: multiCodec("sha2-512"), size: sha512.sizeDigest,
coder: sha2_512hash
),
MHash(mcodec: multiCodec("sha3-224"), size: sha3_224.sizeDigest,
coder: sha3_224hash
),
MHash(mcodec: multiCodec("sha3-256"), size: sha3_256.sizeDigest,
coder: sha3_256hash
),
MHash(mcodec: multiCodec("sha3-384"), size: sha3_384.sizeDigest,
coder: sha3_384hash
),
MHash(mcodec: multiCodec("sha3-512"), size: sha3_512.sizeDigest,
coder: sha3_512hash
),
MHash(mcodec: multiCodec("shake-128"), size: 32, coder: shake_128hash),
MHash(mcodec: multiCodec("shake-256"), size: 64, coder: shake_256hash),
MHash(mcodec: multiCodec("keccak-224"), size: keccak224.sizeDigest,
coder: keccak_224hash
),
MHash(mcodec: multiCodec("keccak-256"), size: keccak256.sizeDigest,
coder: keccak_256hash
),
MHash(mcodec: multiCodec("keccak-384"), size: keccak384.sizeDigest,
coder: keccak_384hash
),
MHash(mcodec: multiCodec("keccak-512"), size: keccak512.sizeDigest,
coder: keccak_512hash
),
MHash(mcodec: multiCodec("blake2b-8"), size: 1, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-16"), size: 2, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-24"), size: 3, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-32"), size: 4, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-40"), size: 5, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-48"), size: 6, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-56"), size: 7, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-64"), size: 8, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-72"), size: 9, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-80"), size: 10, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-88"), size: 11, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-96"), size: 12, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-104"), size: 13, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-112"), size: 14, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-120"), size: 15, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-128"), size: 16, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-136"), size: 17, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-144"), size: 18, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-152"), size: 19, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-160"), size: 20, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-168"), size: 21, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-176"), size: 22, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-184"), size: 23, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-192"), size: 24, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-200"), size: 25, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-208"), size: 26, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-216"), size: 27, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-224"), size: 28, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-232"), size: 29, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-240"), size: 30, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-248"), size: 31, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-256"), size: 32, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-264"), size: 33, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-272"), size: 34, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-280"), size: 35, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-288"), size: 36, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-296"), size: 37, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-304"), size: 38, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-312"), size: 39, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-320"), size: 40, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-328"), size: 41, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-336"), size: 42, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-344"), size: 43, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-352"), size: 44, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-360"), size: 45, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-368"), size: 46, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-376"), size: 47, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-384"), size: 48, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-392"), size: 49, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-400"), size: 50, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-408"), size: 51, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-416"), size: 52, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-424"), size: 53, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-432"), size: 54, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-440"), size: 55, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-448"), size: 56, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-456"), size: 57, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-464"), size: 58, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-472"), size: 59, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-480"), size: 60, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-488"), size: 61, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-496"), size: 62, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-504"), size: 63, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2b-512"), size: 64, coder: blake2Bhash),
MHash(mcodec: multiCodec("blake2s-8"), size: 1, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-16"), size: 2, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-24"), size: 3, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-32"), size: 4, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-40"), size: 5, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-48"), size: 6, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-56"), size: 7, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-64"), size: 8, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-72"), size: 9, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-80"), size: 10, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-88"), size: 11, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-96"), size: 12, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-104"), size: 13, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-112"), size: 14, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-120"), size: 15, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-128"), size: 16, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-136"), size: 17, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-144"), size: 18, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-152"), size: 19, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-160"), size: 20, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-168"), size: 21, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-176"), size: 22, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-184"), size: 23, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-192"), size: 24, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-200"), size: 25, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-208"), size: 26, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-216"), size: 27, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-224"), size: 28, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-232"), size: 29, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-240"), size: 30, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-248"), size: 31, coder: blake2Shash),
MHash(mcodec: multiCodec("blake2s-256"), size: 32, coder: blake2Shash)
]
proc initMultiHashCodeTable(): Table[MultiCodec, MHash] {.compileTime.} =
for item in HashesList:
result[item.mcodec] = item
const CodeHashes = initMultiHashCodeTable()
const
CodeHashes = initMultiHashCodeTable()
proc digestImplWithHash(hash: MHash, data: openArray[byte]): MultiHash =
var buffer: array[MaxHashSize, byte]
@@ -383,9 +356,8 @@ proc digestImplWithoutHash(hash: MHash, data: openArray[byte]): MultiHash =
result.data.writeArray(data)
result.data.finish()
proc digest*(
mhtype: typedesc[MultiHash], hashname: string, data: openArray[byte]
): MhResult[MultiHash] {.inline.} =
proc digest*(mhtype: typedesc[MultiHash], hashname: string,
data: openArray[byte]): MhResult[MultiHash] {.inline.} =
## Perform digest calculation using hash algorithm with name ``hashname`` on
## data array ``data``.
let mc = MultiCodec.codec(hashname)
@@ -398,9 +370,8 @@ proc digest*(
else:
ok(digestImplWithHash(hash, data))
proc digest*(
mhtype: typedesc[MultiHash], hashcode: int, data: openArray[byte]
): MhResult[MultiHash] {.inline.} =
proc digest*(mhtype: typedesc[MultiHash], hashcode: int,
data: openArray[byte]): MhResult[MultiHash] {.inline.} =
## Perform digest calculation using hash algorithm with code ``hashcode`` on
## data array ``data``.
let hash = CodeHashes.getOrDefault(hashcode)
@@ -409,9 +380,8 @@ proc digest*(
else:
ok(digestImplWithHash(hash, data))
proc init*[T](
mhtype: typedesc[MultiHash], hashname: string, mdigest: MDigest[T]
): MhResult[MultiHash] {.inline.} =
proc init*[T](mhtype: typedesc[MultiHash], hashname: string,
mdigest: MDigest[T]): MhResult[MultiHash] {.inline.} =
## Create MultiHash from nimcrypto's `MDigest` object and hash algorithm name
## ``hashname``.
let mc = MultiCodec.codec(hashname)
@@ -426,9 +396,8 @@ proc init*[T](
else:
ok(digestImplWithoutHash(hash, mdigest.data))
proc init*[T](
mhtype: typedesc[MultiHash], hashcode: MultiCodec, mdigest: MDigest[T]
): MhResult[MultiHash] {.inline.} =
proc init*[T](mhtype: typedesc[MultiHash], hashcode: MultiCodec,
mdigest: MDigest[T]): MhResult[MultiHash] {.inline.} =
## Create MultiHash from nimcrypto's `MDigest` and hash algorithm code
## ``hashcode``.
let hash = CodeHashes.getOrDefault(hashcode)
@@ -439,9 +408,8 @@ proc init*[T](
else:
ok(digestImplWithoutHash(hash, mdigest.data))
proc init*(
mhtype: typedesc[MultiHash], hashname: string, bdigest: openArray[byte]
): MhResult[MultiHash] {.inline.} =
proc init*(mhtype: typedesc[MultiHash], hashname: string,
bdigest: openArray[byte]): MhResult[MultiHash] {.inline.} =
## Create MultiHash from array of bytes ``bdigest`` and hash algorithm code
## ``hashcode``.
let mc = MultiCodec.codec(hashname)
@@ -456,9 +424,8 @@ proc init*(
else:
ok(digestImplWithoutHash(hash, bdigest))
proc init*(
mhtype: typedesc[MultiHash], hashcode: MultiCodec, bdigest: openArray[byte]
): MhResult[MultiHash] {.inline.} =
proc init*(mhtype: typedesc[MultiHash], hashcode: MultiCodec,
bdigest: openArray[byte]): MhResult[MultiHash] {.inline.} =
## Create MultiHash from array of bytes ``bdigest`` and hash algorithm code
## ``hashcode``.
let hash = CodeHashes.getOrDefault(hashcode)
@@ -469,9 +436,8 @@ proc init*(
else:
ok(digestImplWithoutHash(hash, bdigest))
proc decode*(
mhtype: typedesc[MultiHash], data: openArray[byte], mhash: var MultiHash
): MhResult[int] =
proc decode*(mhtype: typedesc[MultiHash], data: openArray[byte],
mhash: var MultiHash): MhResult[int] =
## Decode MultiHash value from array of bytes ``data``.
##
## On success decoded MultiHash will be stored into ``mhash`` and number of
@@ -510,10 +476,9 @@ proc decode*(
if not vb.isEnough(int(size)):
return err(ErrDecodeError)
mhash =
?MultiHash.init(
MultiCodec(code), vb.buffer.toOpenArray(vb.offset, vb.offset + int(size) - 1)
)
mhash = ? MultiHash.init(MultiCodec(code),
vb.buffer.toOpenArray(vb.offset,
vb.offset + int(size) - 1))
ok(vb.offset + int(size))
proc validate*(mhtype: typedesc[MultiHash], data: openArray[byte]): bool =
@@ -546,24 +511,24 @@ proc validate*(mhtype: typedesc[MultiHash], data: openArray[byte]): bool =
return false
result = true
proc init*(
mhtype: typedesc[MultiHash], data: openArray[byte]
): MhResult[MultiHash] {.inline.} =
proc init*(mhtype: typedesc[MultiHash],
data: openArray[byte]): MhResult[MultiHash] {.inline.} =
## Create MultiHash from bytes array ``data``.
var hash: MultiHash
discard ?MultiHash.decode(data, hash)
discard ? MultiHash.decode(data, hash)
ok(hash)
proc init*(mhtype: typedesc[MultiHash], data: string): MhResult[MultiHash] {.inline.} =
## Create MultiHash from hexadecimal string representation ``data``.
var hash: MultiHash
try:
discard ?MultiHash.decode(fromHex(data), hash)
discard ? MultiHash.decode(fromHex(data), hash)
ok(hash)
except ValueError:
err(ErrParseError)
proc init58*(mhtype: typedesc[MultiHash], data: string): MultiHash {.inline.} =
proc init58*(mhtype: typedesc[MultiHash],
data: string): MultiHash {.inline.} =
## Create MultiHash from BASE58 encoded string representation ``data``.
if MultiHash.decode(Base58.decode(data), result) == -1:
raise newException(MultihashError, "Incorrect MultiHash binary format")
@@ -576,7 +541,7 @@ proc cmp(a: openArray[byte], b: openArray[byte]): bool {.inline.} =
while n > 0:
dec(n)
diff = int(a[n]) - int(b[n])
res = (res and - not (diff)) or diff
res = (res and -not(diff)) or diff
result = (res == 0)
proc `==`*[T](mh: MultiHash, mdigest: MDigest[T]): bool =
@@ -586,10 +551,8 @@ proc `==`*[T](mh: MultiHash, mdigest: MDigest[T]): bool =
return false
if len(mdigest.data) != mh.size:
return false
result = cmp(
mh.data.buffer.toOpenArray(mh.dpos, mh.dpos + mh.size - 1),
mdigest.data.toOpenArray(0, mdigest.data.high),
)
result = cmp(mh.data.buffer.toOpenArray(mh.dpos, mh.dpos + mh.size - 1),
mdigest.data.toOpenArray(0, mdigest.data.high))
proc `==`*[T](mdigest: MDigest[T], mh: MultiHash): bool {.inline.} =
## Compares MultiHash with nimcrypto's MDigest[T], returns ``true`` if
@@ -605,10 +568,8 @@ proc `==`*(a: MultiHash, b: MultiHash): bool =
return false
if a.size != b.size:
return false
result = cmp(
a.data.buffer.toOpenArray(a.dpos, a.dpos + a.size - 1),
b.data.buffer.toOpenArray(b.dpos, b.dpos + b.size - 1),
)
result = cmp(a.data.buffer.toOpenArray(a.dpos, a.dpos + a.size - 1),
b.data.buffer.toOpenArray(b.dpos, b.dpos + b.size - 1))
proc hex*(value: MultiHash): string =
## Return hexadecimal string representation of MultiHash ``value``.
@@ -620,16 +581,16 @@ proc base58*(value: MultiHash): string =
proc `$`*(mh: MultiHash): string =
## Return string representation of MultiHash ``value``.
let digest = toHex(mh.data.buffer.toOpenArray(mh.dpos, mh.dpos + mh.size - 1))
let digest = toHex(mh.data.buffer.toOpenArray(mh.dpos,
mh.dpos + mh.size - 1))
result = $(mh.mcodec) & "/" & digest
proc write*(vb: var VBuffer, mh: MultiHash) {.inline.} =
## Write MultiHash value ``mh`` to buffer ``vb``.
vb.writeArray(mh.data.buffer)
proc encode*(
mbtype: typedesc[MultiBase], encoding: string, mh: MultiHash
): string {.inline.} =
proc encode*(mbtype: typedesc[MultiBase], encoding: string,
mh: MultiHash): string {.inline.} =
## Get MultiBase encoded representation of ``mh`` using encoding
## ``encoding``.
result = MultiBase.encode(encoding, mh.data.buffer)
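
Taken together, the procs in this file form a small round-trip API: `digest` hashes, `decode` parses and reports how many bytes it consumed, `validate` checks the framing, and `==` compares digests through the branchless `cmp` above. A usage sketch — the import paths are assumptions, and `expect` is taken from `stew/results`, which these modules re-export:

import libp2p/[multihash, multibase]
import stew/byteutils

let data = "hello".toBytes()

# Digest by codec name; MhResult is a Result, so failure is a value.
let mh = MultiHash.digest("sha2-256", data).expect("supported codec")
echo mh                       # sha2-256/<hex digest>
echo mh.hex()                 # raw multihash bytes as hex
echo mh.base58()              # raw multihash bytes as base58

# Wire round-trip: decode returns the number of bytes consumed.
var decoded: MultiHash
let consumed = MultiHash.decode(mh.data.buffer, decoded).expect("valid bytes")
assert decoded == mh and consumed == mh.data.buffer.len

# validate checks the framing without building a MultiHash.
assert MultiHash.validate(mh.data.buffer)

# MultiBase wrapping via the `encode` helper defined above.
echo MultiBase.encode("base58btc", mh)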

View File

@@ -1,5 +1,5 @@
# Nim-LibP2P
# Copyright (c) 2023-2024 Status Research & Development GmbH
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
@@ -7,24 +7,29 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import std/[strutils, sequtils, tables]
import chronos, chronicles, stew/byteutils
import stream/connection, protocols/protocol
import stream/connection,
protocols/protocol
logScope:
topics = "libp2p multistream"
const
MsgSize = 1024
Codec = "/multistream/1.0.0"
MsgSize* = 1024
Codec* = "/multistream/1.0.0"
Na = "na\n"
Ls = "ls\n"
MSCodec* = "\x13" & Codec & "\n"
Na* = "\x03na\n"
Ls* = "\x03ls\n"
type
Matcher* = proc(proto: string): bool {.gcsafe, raises: [].}
Matcher* = proc (proto: string): bool {.gcsafe, raises: [Defect].}
MultiStreamError* = object of LPError
@@ -39,20 +44,23 @@ type
codec*: string
proc new*(T: typedesc[MultistreamSelect]): T =
T(codec: Codec)
T(
codec: MSCodec,
)
template validateSuffix(str: string): untyped =
if str.endsWith("\n"):
str.removeSuffix("\n")
else:
raise (ref MultiStreamError)(msg: "MultistreamSelect failed, malformed message")
if str.endsWith("\n"):
str.removeSuffix("\n")
else:
raise newException(MultiStreamError, "MultistreamSelect failed, malformed message")
proc select*(
_: MultistreamSelect | type MultistreamSelect, conn: Connection, proto: seq[string]
): Future[string] {.async: (raises: [CancelledError, LPStreamError, MultiStreamError]).} =
trace "initiating handshake", conn, codec = Codec
proc select*(m: MultistreamSelect,
conn: Connection,
proto: seq[string]):
Future[string] {.async.} =
trace "initiating handshake", conn, codec = m.codec
## select a remote protocol
await conn.writeLp(Codec & "\n") # write handshake
await conn.write(m.codec) # write handshake
if proto.len() > 0:
trace "selecting proto", conn, proto = proto[0]
await conn.writeLp((proto[0] & "\n")) # select proto
@@ -62,7 +70,7 @@ proc select*(
if s != Codec:
notice "handshake failed", conn, codec = s
raise (ref MultiStreamError)(msg: "MultistreamSelect handshake failed")
raise newException(MultiStreamError, "MultistreamSelect handshake failed")
else:
trace "multistream handshake success", conn
@@ -78,7 +86,7 @@ proc select*(
return proto[0]
elif proto.len > 1:
# Try to negotiate alternatives
let protos = proto[1 ..< proto.len()]
let protos = proto[1..<proto.len()]
trace "selecting one of several protos", conn, protos = protos
for p in protos:
trace "selecting proto", conn, proto = p
@@ -94,31 +102,24 @@ proc select*(
# No alternatives, fail
return ""
proc select*(
_: MultistreamSelect | type MultistreamSelect, conn: Connection, proto: string
): Future[bool] {.async: (raises: [CancelledError, LPStreamError, MultiStreamError]).} =
proc select*(m: MultistreamSelect,
conn: Connection,
proto: string): Future[bool] {.async.} =
if proto.len > 0:
(await MultistreamSelect.select(conn, @[proto])) == proto
return (await m.select(conn, @[proto])) == proto
else:
(await MultistreamSelect.select(conn, @[])) == Codec
return (await m.select(conn, @[])) == Codec
proc select*(
m: MultistreamSelect, conn: Connection
): Future[bool] {.
async: (raises: [CancelledError, LPStreamError, MultiStreamError], raw: true)
.} =
proc select*(m: MultistreamSelect, conn: Connection): Future[bool] =
m.select(conn, "")
proc list*(
m: MultistreamSelect, conn: Connection
): Future[seq[string]] {.
async: (raises: [CancelledError, LPStreamError, MultiStreamError])
.} =
proc list*(m: MultistreamSelect,
conn: Connection): Future[seq[string]] {.async.} =
## list the protocols supported by the remote peer on connection
if not await m.select(conn):
return
await conn.writeLp(Ls) # send ls
await conn.write(Ls) # send ls
var list = newSeq[string]()
let ms = string.fromBytes(await conn.readLp(MsgSize))
@@ -128,87 +129,68 @@ proc list*(
result = list
proc handle*(
_: type MultistreamSelect,
conn: Connection,
protos: seq[string],
matchers = newSeq[Matcher](),
active: bool = false,
): Future[string] {.async: (raises: [CancelledError, LPStreamError, MultiStreamError]).} =
trace "Starting multistream negotiation", conn, handshaked = active
var handshaked = active
while not conn.atEof:
var ms = string.fromBytes(await conn.readLp(MsgSize))
validateSuffix(ms)
if not handshaked and ms != Codec:
debug "expected handshake message", conn, instead = ms
raise (ref MultiStreamError)(
msg: "MultistreamSelect handling failed, invalid first message"
)
trace "handle: got request", conn, ms
if ms.len() <= 0:
trace "handle: invalid proto", conn
await conn.writeLp(Na)
case ms
of "ls":
trace "handle: listing protos", conn
#TODO this doesn't seem to follow spec, each protocol
# should be length prefixed. Not very important
# since LS is getting deprecated
await conn.writeLp(protos.join("\n") & "\n")
of Codec:
if not handshaked:
await conn.writeLp(Codec & "\n")
handshaked = true
else:
trace "handle: sending `na` for duplicate handshake while handshaked", conn
await conn.writeLp(Na)
elif ms in protos or matchers.anyIt(it(ms)):
trace "found handler", conn, protocol = ms
await conn.writeLp(ms & "\n")
conn.protocol = ms
return ms
else:
trace "no handlers", conn, protocol = ms
await conn.writeLp(Na)
proc handle*(
m: MultistreamSelect, conn: Connection, active: bool = false
) {.async: (raises: [CancelledError]).} =
proc handle*(m: MultistreamSelect, conn: Connection, active: bool = false) {.async, gcsafe.} =
trace "Starting multistream handler", conn, handshaked = active
var
protos: seq[string]
matchers: seq[Matcher]
for h in m.handlers:
if h.match != nil:
matchers.add(h.match)
for proto in h.protos:
protos.add(proto)
var handshaked = active
try:
let ms = await MultistreamSelect.handle(conn, protos, matchers, active)
for h in m.handlers:
if (h.match != nil and h.match(ms)) or h.protos.contains(ms):
trace "found handler", conn, protocol = ms
while not conn.atEof:
var ms = string.fromBytes(await conn.readLp(MsgSize))
validateSuffix(ms)
var protocolHolder = h
let maxIncomingStreams = protocolHolder.protocol.maxIncomingStreams
if protocolHolder.openedStreams.getOrDefault(conn.peerId) >= maxIncomingStreams:
debug "Max streams for protocol reached, blocking new stream",
conn, protocol = ms, maxIncomingStreams
return
protocolHolder.openedStreams.inc(conn.peerId)
try:
await protocolHolder.protocol.handler(conn, ms)
finally:
protocolHolder.openedStreams.inc(conn.peerId, -1)
if protocolHolder.openedStreams[conn.peerId] == 0:
protocolHolder.openedStreams.del(conn.peerId)
return
debug "no handlers", conn, ms
if not handshaked and ms != Codec:
notice "expected handshake message", conn, instead=ms
raise newException(CatchableError,
"MultistreamSelect handling failed, invalid first message")
trace "handle: got request", conn, ms
if ms.len() <= 0:
trace "handle: invalid proto", conn
await conn.write(Na)
if m.handlers.len() == 0:
trace "handle: sending `na` for protocol", conn, protocol = ms
await conn.write(Na)
continue
case ms:
of "ls":
trace "handle: listing protos", conn
var protos = ""
for h in m.handlers:
for proto in h.protos:
protos &= (proto & "\n")
await conn.writeLp(protos)
of Codec:
if not handshaked:
await conn.write(m.codec)
handshaked = true
else:
trace "handle: sending `na` for duplicate handshake while handshaked",
conn
await conn.write(Na)
else:
for h in m.handlers:
if (not isNil(h.match) and h.match(ms)) or h.protos.contains(ms):
trace "found handler", conn, protocol = ms
var protocolHolder = h
let maxIncomingStreams = protocolHolder.protocol.maxIncomingStreams
if protocolHolder.openedStreams.getOrDefault(conn.peerId) >= maxIncomingStreams:
debug "Max streams for protocol reached, blocking new stream",
conn, protocol = ms, maxIncomingStreams
return
protocolHolder.openedStreams.inc(conn.peerId)
try:
await conn.writeLp(ms & "\n")
conn.protocol = ms
await protocolHolder.protocol.handler(conn, ms)
finally:
protocolHolder.openedStreams.inc(conn.peerId, -1)
if protocolHolder.openedStreams[conn.peerId] == 0:
protocolHolder.openedStreams.del(conn.peerId)
return
debug "no handlers", conn, protocol = ms
await conn.write(Na)
except CancelledError as exc:
raise exc
except CatchableError as exc:
@@ -218,65 +200,37 @@ proc handle*(
trace "Stopped multistream handler", conn
proc addHandler*(
m: MultistreamSelect,
codecs: seq[string],
protocol: LPProtocol,
matcher: Matcher = nil,
) =
proc addHandler*(m: MultistreamSelect,
codecs: seq[string],
protocol: LPProtocol,
matcher: Matcher = nil) =
trace "registering protocols", protos = codecs
m.handlers.add(HandlerHolder(protos: codecs, protocol: protocol, match: matcher))
m.handlers.add(HandlerHolder(protos: codecs,
protocol: protocol,
match: matcher))
proc addHandler*(
m: MultistreamSelect, codec: string, protocol: LPProtocol, matcher: Matcher = nil
) =
proc addHandler*(m: MultistreamSelect,
codec: string,
protocol: LPProtocol,
matcher: Matcher = nil) =
addHandler(m, @[codec], protocol, matcher)
proc addHandler*[E](
m: MultistreamSelect,
codec: string,
handler:
LPProtoHandler |
proc(conn: Connection, proto: string): InternalRaisesFuture[void, E],
matcher: Matcher = nil,
) =
proc addHandler*(m: MultistreamSelect,
codec: string,
handler: LPProtoHandler,
matcher: Matcher = nil) =
## helper to allow registering pure handlers
trace "registering proto handler", proto = codec
let protocol = new LPProtocol
protocol.codec = codec
protocol.handler = handler
m.handlers.add(HandlerHolder(protos: @[codec], protocol: protocol, match: matcher))
m.handlers.add(HandlerHolder(protos: @[codec],
protocol: protocol,
match: matcher))
proc start*(m: MultistreamSelect) {.async: (raises: [CancelledError]).} =
# Nim 1.6.18: Using `mapIt` results in a seq of `.Raising([])`
# TODO https://github.com/nim-lang/Nim/issues/23445
var futs = newSeqOfCap[Future[void].Raising([CancelledError])](m.handlers.len)
for it in m.handlers:
futs.add it.protocol.start()
try:
await allFutures(futs)
for fut in futs:
await fut
except CancelledError as exc:
var pending: seq[Future[void].Raising([])]
doAssert m.handlers.len == futs.len, "Handlers modified while starting"
for i, fut in futs:
if not fut.finished:
pending.add fut.cancelAndWait()
elif fut.completed:
pending.add m.handlers[i].protocol.stop()
else:
static:
doAssert typeof(fut).E is (CancelledError,)
await noCancel allFutures(pending)
raise exc
proc start*(m: MultistreamSelect) {.async.} =
await allFutures(m.handlers.mapIt(it.protocol.start()))
proc stop*(m: MultistreamSelect) {.async: (raises: []).} =
# Nim 1.6.18: Using `mapIt` results in a seq of `.Raising([CancelledError])`
var futs = newSeqOfCap[Future[void].Raising([])](m.handlers.len)
for it in m.handlers:
futs.add it.protocol.stop()
await noCancel allFutures(futs)
for fut in futs:
await fut
proc stop*(m: MultistreamSelect) {.async.} =
await allFutures(m.handlers.mapIt(it.protocol.stop()))
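
The old constants spell out the wire format that the new `writeLp` calls produce implicitly: every multistream message is a varint length prefix followed by a payload ending in "\n" (which is what `validateSuffix` enforces). For payloads under 128 bytes the varint is a single byte, hence the literal "\x13" (19) and "\x03" (3) prefixes above. A self-contained sketch of that single-byte case:

# Builds one multistream frame: varint length prefix + payload + "\n".
proc frame(msg: string): string =
  let payload = msg & "\n"
  doAssert payload.len < 128, "single-byte varint sketch only"
  char(payload.len) & payload

doAssert frame("/multistream/1.0.0") == "\x13/multistream/1.0.0\n"
doAssert frame("na") == "\x03na\n"
doAssert frame("ls") == "\x03ls\n"

# A negotiation then looks like (initiator sends, listener echoes on accept):
#   -> frame("/multistream/1.0.0")   handshake
#   <- frame("/multistream/1.0.0")
#   -> frame("/my/proto/1.0.0")      select
#   <- frame("/my/proto/1.0.0")      accepted, or frame("na") to refuse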

View File

@@ -1,5 +1,5 @@
# Nim-LibP2P
# Copyright (c) 2023-2024 Status Research & Development GmbH
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
@@ -7,25 +7,35 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import pkg/[chronos, chronicles, stew/byteutils]
import ../../stream/connection, ../../utility, ../../varint, ../../vbuffer, ../muxer
import pkg/[chronos, nimcrypto/utils, chronicles, stew/byteutils]
import ../../stream/connection,
../../utility,
../../varint,
../../vbuffer,
../muxer
logScope:
topics = "libp2p mplexcoder"
type
MessageType* {.pure.} = enum
New
MsgIn
MsgOut
CloseIn
CloseOut
ResetIn
New,
MsgIn,
MsgOut,
CloseIn,
CloseOut,
ResetIn,
ResetOut
Msg* = tuple[id: uint64, msgType: MessageType, data: seq[byte]]
Msg* = tuple
id: uint64
msgType: MessageType
data: seq[byte]
InvalidMplexMsgType* = object of MuxerError
@@ -35,9 +45,7 @@ const MaxMsgSize* = 1 shl 20 # 1mb
proc newInvalidMplexMsgType*(): ref InvalidMplexMsgType =
newException(InvalidMplexMsgType, "invalid message type")
proc readMsg*(
conn: Connection
): Future[Msg] {.async: (raises: [CancelledError, LPStreamError, MuxerError]).} =
proc readMsg*(conn: Connection): Future[Msg] {.async, gcsafe.} =
let header = await conn.readVarint()
trace "read header varint", varint = header, conn
@@ -50,9 +58,10 @@ proc readMsg*(
return (header shr 3, MessageType(msgType), data)
proc writeMsg*(
conn: Connection, id: uint64, msgType: MessageType, data: seq[byte] = @[]
): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true).} =
proc writeMsg*(conn: Connection,
id: uint64,
msgType: MessageType,
data: seq[byte] = @[]): Future[void] =
var
left = data.len
offset = 0
@@ -60,11 +69,8 @@ proc writeMsg*(
# Split message into length-prefixed chunks
while left > 0 or data.len == 0:
let chunkSize =
if left > MaxMsgSize:
MaxMsgSize - 64
else:
left
let
chunkSize = if left > MaxMsgSize: MaxMsgSize - 64 else: left
buf.writePBVarint(id shl 3 or ord(msgType).uint64)
buf.writeSeq(data.toOpenArray(offset, offset + chunkSize - 1))
@@ -81,7 +87,8 @@ proc writeMsg*(
# message gets written before some of the chunks
conn.write(buf.buffer)
proc writeMsg*(
conn: Connection, id: uint64, msgType: MessageType, data: string
): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true).} =
proc writeMsg*(conn: Connection,
id: uint64,
msgType: MessageType,
data: string): Future[void] =
conn.writeMsg(id, msgType, data.toBytes())
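Note: both versions of the coder use the same wire format: a varint header packing the channel id and message type as `id shl 3 or msgType`, followed by a length-prefixed payload, with bodies above `MaxMsgSize` split into chunks. Illustrative header arithmetic (not library code):

    let header = 42'u64                               # example varint off the wire
    let id = header shr 3                             # upper bits: channel id (5 here)
    let msgType = MessageType(int(header and 0b111))  # low 3 bits: MsgOut here
    doAssert ((id shl 3) or ord(msgType).uint64) == header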

View File

@@ -1,5 +1,5 @@
# Nim-LibP2P
# Copyright (c) 2023-2024 Status Research & Development GmbH
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
@@ -7,12 +7,17 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import std/[oids, strformat]
import pkg/[chronos, chronicles, metrics]
import
./coder, ../muxer, ../../stream/[bufferstream, connection, streamseq], ../../peerinfo
import pkg/[chronos, chronicles, metrics, nimcrypto/utils]
import ./coder,
../muxer,
../../stream/[bufferstream, connection, streamseq],
../../peerinfo
export connection
@@ -20,15 +25,13 @@ logScope:
topics = "libp2p mplexchannel"
when defined(libp2p_mplex_metrics):
declareHistogram libp2p_mplex_qlen,
"message queue length",
declareHistogram libp2p_mplex_qlen, "message queue length",
buckets = [0.0, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 512.0]
declareCounter libp2p_mplex_qlenclose, "closed because of max queuelen"
declareHistogram libp2p_mplex_qtime, "message queuing time"
when defined(libp2p_network_protocols_metrics):
declareCounter libp2p_protocols_bytes,
"total sent or received bytes", ["protocol", "direction"]
declareCounter libp2p_protocols_bytes, "total sent or received bytes", ["protocol", "direction"]
## Channel half-closed states
##
@@ -42,41 +45,39 @@ when defined(libp2p_network_protocols_metrics):
## EOF marker
const
MaxWrites = 1024
##\
MaxWrites = 1024 ##\
## Maximum number of in-flight writes - after this, we disconnect the peer
LPChannelTrackerName* = "LPChannel"
type LPChannel* = ref object of BufferStream
id*: uint64 # channel id
name*: string # name of the channel (for debugging)
conn*: Connection # wrapped connection used for writing
initiator*: bool # initiated remotely or locally flag
isOpen*: bool # has channel been opened
closedLocal*: bool # has channel been closed locally
remoteReset*: bool # has channel been remotely reset
localReset*: bool # has channel been reset locally
msgCode*: MessageType # cached in/out message code
closeCode*: MessageType # cached in/out close code
resetCode*: MessageType # cached in/out reset code
writes*: int # In-flight writes
type
LPChannel* = ref object of BufferStream
id*: uint64 # channel id
name*: string # name of the channel (for debugging)
conn*: Connection # wrapped connection used for writing
initiator*: bool # initiated remotely or locally flag
isOpen*: bool # has channel been opened
closedLocal*: bool # has channel been closed locally
remoteReset*: bool # has channel been remotely reset
localReset*: bool # has channel been reset locally
msgCode*: MessageType # cached in/out message code
closeCode*: MessageType # cached in/out close code
resetCode*: MessageType # cached in/out reset code
writes*: int # In-flight writes
writesBytes*: int # In-flight writes bytes
func shortLog*(s: LPChannel): auto =
try:
if s == nil:
"LPChannel(nil)"
if s.isNil: "LPChannel(nil)"
elif s.name != $s.oid and s.name.len > 0:
&"{shortLog(s.conn.peerId)}:{s.oid}:{s.name}"
else:
&"{shortLog(s.conn.peerId)}:{s.oid}"
else: &"{shortLog(s.conn.peerId)}:{s.oid}"
except ValueError as exc:
raiseAssert(exc.msg)
raise newException(Defect, exc.msg)
chronicles.formatIt(LPChannel):
shortLog(it)
chronicles.formatIt(LPChannel): shortLog(it)
proc open*(s: LPChannel) {.async: (raises: [CancelledError, LPStreamError]).} =
proc open*(s: LPChannel) {.async, gcsafe.} =
trace "Opening channel", s, conn = s.conn
if s.conn.isClosed:
return
@@ -85,20 +86,20 @@ proc open*(s: LPChannel) {.async: (raises: [CancelledError, LPStreamError]).} =
s.isOpen = true
except CancelledError as exc:
raise exc
except LPStreamError as exc:
except CatchableError as exc:
await s.conn.close()
raise exc
method closed*(s: LPChannel): bool =
s.closedLocal
proc closeUnderlying(s: LPChannel): Future[void] {.async: (raises: []).} =
proc closeUnderlying(s: LPChannel): Future[void] {.async.} =
## Channels may be closed for reading and writing in any order - we'll close
## the underlying bufferstream when both directions are closed
if s.closedLocal and s.atEof():
await procCall BufferStream(s).close()
proc reset*(s: LPChannel) {.async: (raises: []).} =
proc reset*(s: LPChannel) {.async, gcsafe.} =
if s.isClosed:
trace "Already closed", s
return
@@ -111,21 +112,22 @@ proc reset*(s: LPChannel) {.async: (raises: []).} =
if s.isOpen and not s.conn.isClosed:
# If the connection is still active, notify the other end
proc resetMessage() {.async: (raises: []).} =
proc resetMessage() {.async.} =
try:
trace "sending reset message", s, conn = s.conn
await noCancel s.conn.writeMsg(s.id, s.resetCode) # write reset
except LPStreamError as exc:
trace "Can't send reset message", s, conn = s.conn, msg = exc.msg
await s.conn.writeMsg(s.id, s.resetCode) # write reset
except CatchableError as exc:
# No cancellations
await s.conn.close()
trace "Can't send reset message", s, conn = s.conn, msg = exc.msg
asyncSpawn resetMessage()
await s.closeImpl()
await s.closeImpl() # noraises, nocancels
trace "Channel reset", s
method close*(s: LPChannel) {.async: (raises: []).} =
method close*(s: LPChannel) {.async, gcsafe.} =
## Close channel for writing - a message will be sent to the other peer
## informing them that the channel is closed and that we're waiting for
## their acknowledgement.
@@ -139,9 +141,10 @@ method close*(s: LPChannel) {.async: (raises: []).} =
if s.isOpen and not s.conn.isClosed:
try:
await s.conn.writeMsg(s.id, s.closeCode) # write close
except CancelledError:
except CancelledError as exc:
await s.conn.close()
except LPStreamError as exc:
raise exc
except CatchableError as exc:
# It's harmless that close message cannot be sent - the connection is
# likely down already
await s.conn.close()
@@ -155,15 +158,16 @@ method initStream*(s: LPChannel) =
if s.objName.len == 0:
s.objName = LPChannelTrackerName
s.timeoutHandler = proc(): Future[void] {.async: (raises: [], raw: true).} =
s.timeoutHandler = proc(): Future[void] {.gcsafe.} =
trace "Idle timeout expired, resetting LPChannel", s
s.reset()
procCall BufferStream(s).initStream()
method readOnce*(
s: LPChannel, pbytes: pointer, nbytes: int
): Future[int] {.async: (raises: [CancelledError, LPStreamError]).} =
method readOnce*(s: LPChannel,
pbytes: pointer,
nbytes: int):
Future[int] {.async.} =
## Mplex relies on reading being done regularly from every channel, or all
## channels are blocked - in particular, this means that reading from one
## channel must not be done from within a callback / read handler of another
@@ -180,24 +184,21 @@ method readOnce*(
let bytes = await procCall BufferStream(s).readOnce(pbytes, nbytes)
when defined(libp2p_network_protocols_metrics):
if s.protocol.len > 0:
libp2p_protocols_bytes.inc(bytes.int64, labelValues = [s.protocol, "in"])
libp2p_protocols_bytes.inc(bytes.int64, labelValues=[s.protocol, "in"])
trace "readOnce", s, bytes
if bytes == 0:
await s.closeUnderlying()
return bytes
except CancelledError as exc:
await s.reset()
raise exc
except LPStreamError as exc:
# Resetting is necessary because data has been lost in s.readBuf and
# there's no way to gracefully recover / use the channel any more
except CatchableError as exc:
# readOnce in BufferStream generally raises on EOF or cancellation - for
# the former, resetting is harmless, for the latter it's necessary because
# data has been lost in s.readBuf and there's no way to gracefully recover /
# use the channel any more
await s.reset()
raise newLPStreamConnDownError(exc)
proc prepareWrite(
s: LPChannel, msg: seq[byte]
): Future[void] {.async: (raises: [CancelledError, LPStreamError]).} =
proc prepareWrite(s: LPChannel, msg: seq[byte]): Future[void] {.async.} =
# prepareWrite is the slow path of writing a message - see conditions in
# write
if s.remoteReset:
@@ -214,7 +215,7 @@ proc prepareWrite(
debug "Closing connection, too many in-flight writes on channel",
s, conn = s.conn, writes = s.writes
when defined(libp2p_mplex_metrics):
libp2p_mplex_qlenclose.inc()
libp2p_mplex_qlenclose.inc()
await s.reset()
await s.conn.close()
return
@@ -225,12 +226,10 @@ proc prepareWrite(
await s.conn.writeMsg(s.id, s.msgCode, msg)
proc completeWrite(
s: LPChannel,
fut: Future[void].Raising([CancelledError, LPStreamError]),
msgLen: int,
): Future[void] {.async: (raises: [CancelledError, LPStreamError]).} =
s: LPChannel, fut: Future[void], msgLen: int): Future[void] {.async.} =
try:
s.writes += 1
s.writesBytes += msgLen
when defined(libp2p_mplex_metrics):
libp2p_mplex_qlen.observe(s.writes.int64 - 1)
@@ -239,11 +238,9 @@ proc completeWrite(
else:
await fut
when defined(libp2p_network_protocols_metrics):
when defined(libp2p_network_protocol_metrics):
if s.protocol.len > 0:
# This crashes on Nim 2.0.2 with `--mm:orc` during `nimble test`
# https://github.com/status-im/nim-metrics/issues/79
libp2p_protocols_bytes.inc(msgLen.int64, labelValues = [s.protocol, "out"])
libp2p_protocols_bytes.inc(msgLen.int64, labelValues=[s.protocol, "out"])
s.activity = true
except CancelledError as exc:
@@ -255,21 +252,24 @@ proc completeWrite(
raise exc
except LPStreamEOFError as exc:
raise exc
except LPStreamError as exc:
except CatchableError as exc:
trace "exception in lpchannel write handler", s, msg = exc.msg
await s.reset()
await s.conn.close()
raise newLPStreamConnDownError(exc)
finally:
s.writes -= 1
s.writesBytes -= msgLen
method write*(
s: LPChannel, msg: seq[byte]
): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true).} =
method queuedSendBytes*(channel: LPChannel): int = channel.writesBytes
method write*(s: LPChannel, msg: seq[byte]): Future[void] =
## Write to mplex channel - there may be up to MaxWrite concurrent writes
## pending after which the peer is disconnected
let closed = s.closedLocal or s.conn.closed
let
closed = s.closedLocal or s.conn.closed
let fut =
if (not closed) and msg.len > 0 and s.writes < MaxWrites and s.isOpen:
@@ -282,17 +282,16 @@ method write*(
s.completeWrite(fut, msg.len)
method getWrapped*(s: LPChannel): Connection =
s.conn
method getWrapped*(s: LPChannel): Connection = s.conn
proc init*(
L: type LPChannel,
id: uint64,
conn: Connection,
initiator: bool,
name: string = "",
timeout: Duration = DefaultChanTimeout,
): LPChannel =
L: type LPChannel,
id: uint64,
conn: Connection,
initiator: bool,
name: string = "",
timeout: Duration = DefaultChanTimeout): LPChannel =
let chann = L(
id: id,
name: name,
@@ -303,17 +302,12 @@ proc init*(
msgCode: if initiator: MessageType.MsgOut else: MessageType.MsgIn,
closeCode: if initiator: MessageType.CloseOut else: MessageType.CloseIn,
resetCode: if initiator: MessageType.ResetOut else: MessageType.ResetIn,
dir: if initiator: Direction.Out else: Direction.In,
)
dir: if initiator: Direction.Out else: Direction.In)
chann.initStream()
when chronicles.enabledLogLevel == LogLevel.TRACE:
chann.name =
if chann.name.len > 0:
chann.name
else:
$chann.oid
chann.name = if chann.name.len > 0: chann.name else: $chann.oid
trace "Created new lpchannel", s = chann, id, initiator

View File

@@ -1,5 +1,5 @@
# Nim-LibP2P
# Copyright (c) 2023-2024 Status Research & Development GmbH
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
@@ -7,18 +7,20 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import tables, sequtils, oids
import chronos, chronicles, stew/byteutils, metrics
import
../muxer,
../../stream/connection,
../../stream/bufferstream,
../../utility,
../../peerinfo,
./coder,
./lpchannel
import ../muxer,
../../stream/connection,
../../stream/bufferstream,
../../utility,
../../peerinfo,
./coder,
./lpchannel
export muxer
@@ -27,10 +29,12 @@ logScope:
const MplexCodec* = "/mplex/6.7.0"
const MaxChannelCount = 200
const
MaxChannelCount = 200
when defined(libp2p_expensive_metrics):
declareGauge(libp2p_mplex_channels, "mplex channels", labels = ["initiator", "peer"])
declareGauge(libp2p_mplex_channels,
"mplex channels", labels = ["initiator", "peer"])
type
InvalidChannelIdError* = object of MuxerError
@@ -47,8 +51,7 @@ type
func shortLog*(m: Mplex): auto =
shortLog(m.connection)
chronicles.formatIt(Mplex):
shortLog(it)
chronicles.formatIt(Mplex): shortLog(it)
proc newTooManyChannels(): ref TooManyChannels =
newException(TooManyChannels, "max allowed channel count exceeded")
@@ -56,7 +59,7 @@ proc newTooManyChannels(): ref TooManyChannels =
proc newInvalidChannelIdError(): ref InvalidChannelIdError =
newException(InvalidChannelIdError, "max allowed channel count exceeded")
proc cleanupChann(m: Mplex, chann: LPChannel) {.async: (raises: []), inline.} =
proc cleanupChann(m: Mplex, chann: LPChannel) {.async, inline.} =
## remove the local channel from the internal tables
##
try:
@@ -67,31 +70,31 @@ proc cleanupChann(m: Mplex, chann: LPChannel) {.async: (raises: []), inline.} =
when defined(libp2p_expensive_metrics):
libp2p_mplex_channels.set(
m.channels[chann.initiator].len.int64,
labelValues = [$chann.initiator, $m.connection.peerId],
)
except CancelledError as exc:
labelValues = [$chann.initiator, $m.connection.peerId])
except CatchableError as exc:
warn "Error cleaning up mplex channel", m, chann, msg = exc.msg
proc newStreamInternal*(
m: Mplex,
initiator: bool = true,
chanId: uint64 = 0,
name: string = "",
timeout: Duration,
): LPChannel {.gcsafe, raises: [InvalidChannelIdError].} =
proc newStreamInternal*(m: Mplex,
initiator: bool = true,
chanId: uint64 = 0,
name: string = "",
timeout: Duration): LPChannel
{.gcsafe, raises: [Defect, InvalidChannelIdError].} =
## create new channel/stream
##
let id =
if initiator:
m.currentId.inc()
m.currentId
else:
chanId
let id = if initiator:
m.currentId.inc(); m.currentId
else: chanId
if id in m.channels[initiator]:
raise newInvalidChannelIdError()
result = LPChannel.init(id, m.connection, initiator, name, timeout = timeout)
result = LPChannel.init(
id,
m.connection,
initiator,
name,
timeout = timeout)
result.peerId = m.connection.peerId
result.observedAddr = m.connection.observedAddr
@@ -108,17 +111,21 @@ proc newStreamInternal*(
when defined(libp2p_expensive_metrics):
libp2p_mplex_channels.set(
m.channels[initiator].len.int64, labelValues = [$initiator, $m.connection.peerId]
)
m.channels[initiator].len.int64,
labelValues = [$initiator, $m.connection.peerId])
proc handleStream(m: Mplex, chann: LPChannel) {.async: (raises: []).} =
proc handleStream(m: Mplex, chann: LPChannel) {.async.} =
## call the muxer stream handler for this channel
##
await m.streamHandler(chann)
trace "finished handling stream", m, chann
doAssert(chann.closed, "connection not closed by handler!")
try:
await m.streamHandler(chann)
trace "finished handling stream", m, chann
doAssert(chann.closed, "connection not closed by handler!")
except CatchableError as exc:
trace "Exception in mplex stream handler", m, chann, msg = exc.msg
await chann.reset()
method handle*(m: Mplex) {.async: (raises: []).} =
method handle*(m: Mplex) {.async, gcsafe.} =
trace "Starting mplex handler", m
try:
while not m.connection.atEof:
@@ -146,7 +153,7 @@ method handle*(m: Mplex) {.async: (raises: []).} =
else:
if m.channels[false].len > m.maxChannCount - 1:
warn "too many channels created by remote peer",
allowedMax = MaxChannelCount, m
allowedMax = MaxChannelCount, m
raise newTooManyChannels()
let name = string.fromBytes(data)
@@ -154,64 +161,59 @@ method handle*(m: Mplex) {.async: (raises: []).} =
trace "Processing channel message", m, channel, data = data.shortLog
case msgType
of MessageType.New:
trace "created channel", m, channel
case msgType:
of MessageType.New:
trace "created channel", m, channel
if m.streamHandler != nil:
# Launch handler task
# All the errors are handled inside `handleStream()` procedure.
asyncSpawn m.handleStream(channel)
of MessageType.MsgIn, MessageType.MsgOut:
if data.len > MaxMsgSize:
warn "attempting to send a packet larger than allowed",
allowed = MaxMsgSize, channel
raise newLPStreamLimitError()
if not isNil(m.streamHandler):
# Launch handler task
# All the errors are handled inside `handleStream()` procedure.
asyncSpawn m.handleStream(channel)
trace "pushing data to channel", m, channel, len = data.len
try:
await channel.pushData(data)
trace "pushed data to channel", m, channel, len = data.len
except LPStreamClosedError as exc:
# Channel is being closed, but `cleanupChann` was not yet triggered.
trace "pushing data to channel failed",
m, channel, len = data.len, msg = exc.msg
discard # Ignore message, same as if `cleanupChann` had completed.
of MessageType.CloseIn, MessageType.CloseOut:
await channel.pushEof()
of MessageType.ResetIn, MessageType.ResetOut:
channel.remoteReset = true
await channel.reset()
of MessageType.MsgIn, MessageType.MsgOut:
if data.len > MaxMsgSize:
warn "attempting to send a packet larger than allowed",
allowed = MaxMsgSize, channel
raise newLPStreamLimitError()
trace "pushing data to channel", m, channel, len = data.len
try:
await channel.pushData(data)
trace "pushed data to channel", m, channel, len = data.len
except LPStreamClosedError as exc:
# Channel is being closed, but `cleanupChann` was not yet triggered.
trace "pushing data to channel failed", m, channel, len = data.len,
msg = exc.msg
discard # Ignore message, same as if `cleanupChann` had completed.
of MessageType.CloseIn, MessageType.CloseOut:
await channel.pushEof()
of MessageType.ResetIn, MessageType.ResetOut:
channel.remoteReset = true
await channel.reset()
except CancelledError:
debug "Unexpected cancellation in mplex handler", m
except LPStreamEOFError as exc:
trace "Stream EOF", m, msg = exc.msg
except LPStreamError as exc:
debug "Unexpected stream exception in mplex read loop", m, msg = exc.msg
except MuxerError as exc:
debug "Unexpected muxer exception in mplex read loop", m, msg = exc.msg
except CatchableError as exc:
debug "Unexpected exception in mplex read loop", m, msg = exc.msg
finally:
await m.close()
trace "Stopped mplex handler", m
proc new*(
M: type Mplex,
conn: Connection,
inTimeout: Duration = DefaultChanTimeout,
outTimeout: Duration = DefaultChanTimeout,
maxChannCount: int = MaxChannelCount,
): Mplex =
M(
connection: conn,
proc new*(M: type Mplex,
conn: Connection,
inTimeout, outTimeout: Duration = DefaultChanTimeout,
maxChannCount: int = MaxChannelCount): Mplex =
M(connection: conn,
inChannTimeout: inTimeout,
outChannTimeout: outTimeout,
oid: genOid(),
maxChannCount: maxChannCount,
)
maxChannCount: maxChannCount)
method newStream*(
m: Mplex, name: string = "", lazy: bool = false
): Future[Connection] {.async: (raises: [CancelledError, LPStreamError, MuxerError]).} =
method newStream*(m: Mplex,
name: string = "",
lazy: bool = false): Future[Connection] {.async, gcsafe.} =
let channel = m.newStreamInternal(timeout = m.inChannTimeout)
if not lazy:
@@ -219,7 +221,7 @@ method newStream*(
return Connection(channel)
method close*(m: Mplex) {.async: (raises: []).} =
method close*(m: Mplex) {.async, gcsafe.} =
if m.isClosed:
trace "Already closed", m
return
@@ -246,9 +248,3 @@ method close*(m: Mplex) {.async: (raises: []).} =
m.channels[true].clear()
trace "Closed mplex", m
method getStreams*(m: Mplex): seq[Connection] =
for c in m.channels[false].values:
result.add(c)
for c in m.channels[true].values:
result.add(c)
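Note: a quick end-to-end sketch of driving this muxer, hedged as follows: `conn` is assumed to be an already-negotiated `Connection`, and the import path matches this repository's layout.

    import chronos, stew/byteutils
    import libp2p/muxers/mplex/mplex

    proc demo(conn: Connection) {.async.} =
      let mplex = Mplex.new(conn)
      asyncSpawn mplex.handle()                 # run the read loop
      let stream = await mplex.newStream("demo")
      await stream.write("hello".toBytes())
      await stream.close()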

View File

@@ -1,5 +1,5 @@
# Nim-LibP2P
# Copyright (c) 2023-2024 Status Research & Development GmbH
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
@@ -7,65 +7,86 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import chronos, chronicles
import ../stream/connection, ../errors
import ../protocols/protocol,
../stream/connection,
../errors
logScope:
topics = "libp2p muxer"
const DefaultChanTimeout* = 5.minutes
const
DefaultChanTimeout* = 5.minutes
type
MuxerError* = object of LPError
TooManyChannels* = object of MuxerError
StreamHandler* = proc(conn: Connection): Future[void] {.async: (raises: []).}
MuxerHandler* = proc(muxer: Muxer): Future[void] {.async: (raises: []).}
StreamHandler* = proc(conn: Connection): Future[void] {.gcsafe, raises: [Defect].}
MuxerHandler* = proc(muxer: Muxer): Future[void] {.gcsafe, raises: [Defect].}
Muxer* = ref object of RootObj
streamHandler*: StreamHandler
handler*: Future[void].Raising([])
connection*: Connection
# user provider proc that returns a constructed Muxer
MuxerConstructor* = proc(conn: Connection): Muxer {.gcsafe, closure, raises: [].}
MuxerConstructor* = proc(conn: Connection): Muxer {.gcsafe, closure, raises: [Defect].}
# this wraps a creator proc that knows how to make muxers
MuxerProvider* = object
MuxerProvider* = ref object of LPProtocol
newMuxer*: MuxerConstructor
codec*: string
streamHandler*: StreamHandler # triggered every time there is a new stream, called for any muxer instance
muxerHandler*: MuxerHandler # triggered every time there is a new muxed connection created
func shortLog*(m: Muxer): auto =
if m == nil:
"nil"
else:
shortLog(m.connection)
chronicles.formatIt(Muxer):
shortLog(it)
func shortLog*(m: Muxer): auto = shortLog(m.connection)
chronicles.formatIt(Muxer): shortLog(it)
# muxer interface
method newStream*(
m: Muxer, name: string = "", lazy: bool = false
): Future[Connection] {.
base, async: (raises: [CancelledError, LPStreamError, MuxerError], raw: true)
.} =
raiseAssert("Not implemented!")
method close*(m: Muxer) {.base, async: (raises: []).} =
if m.connection != nil:
await m.connection.close()
method handle*(m: Muxer): Future[void] {.base, async: (raises: []).} =
discard
method newStream*(m: Muxer, name: string = "", lazy: bool = false):
Future[Connection] {.base, async, gcsafe.} = discard
method close*(m: Muxer) {.base, async, gcsafe.} = discard
method handle*(m: Muxer): Future[void] {.base, async, gcsafe.} = discard
proc new*(
T: typedesc[MuxerProvider], creator: MuxerConstructor, codec: string
): T {.gcsafe.} =
let muxerProvider = T(newMuxer: creator, codec: codec)
T: typedesc[MuxerProvider],
creator: MuxerConstructor,
codec: string): T {.gcsafe.} =
let muxerProvider = T(newMuxer: creator)
muxerProvider.codec = codec
muxerProvider.init()
muxerProvider
method getStreams*(m: Muxer): seq[Connection] {.base.} =
raiseAssert("Not implemented!")
method init(c: MuxerProvider) =
proc handler(conn: Connection, proto: string) {.async, gcsafe, closure.} =
trace "starting muxer handler", proto=proto, conn
try:
let
muxer = c.newMuxer(conn)
if not isNil(c.streamHandler):
muxer.streamHandler = c.streamHandler
var futs = newSeq[Future[void]]()
futs &= muxer.handle()
# finally await both the futures
if not isNil(c.muxerHandler):
await c.muxerHandler(muxer)
when defined(libp2p_agents_metrics):
conn.shortAgent = muxer.connection.shortAgent
checkFutures(await allFinished(futs))
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "exception in muxer handler", exc = exc.msg, conn, proto
finally:
await conn.close()
c.handler = handler
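Note: a `MuxerProvider` ties a constructor to a codec so that connection upgrade can negotiate it. A hedged construction sketch using only names that appear in this diff:

    proc mplexCreator(conn: Connection): Muxer {.gcsafe.} =
      Mplex.new(conn)

    let provider = MuxerProvider.new(mplexCreator, MplexCodec)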

View File

@@ -1,5 +1,5 @@
# Nim-LibP2P
# Copyright (c) 2023-2024 Status Research & Development GmbH
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
@@ -7,11 +7,15 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import sequtils, std/[tables]
import chronos, chronicles, metrics, stew/[endians2, byteutils, objects]
import ../muxer, ../../stream/connection
import ../muxer,
../../stream/connection
export muxer
@@ -21,21 +25,18 @@ logScope:
const
YamuxCodec* = "/yamux/1.0.0"
YamuxVersion = 0.uint8
YamuxDefaultWindowSize* = 256000
MaxSendQueueSize = 256000
DefaultWindowSize = 256000
MaxChannelCount = 200
when defined(libp2p_yamux_metrics):
declareGauge libp2p_yamux_channels, "yamux channels", labels = ["initiator", "peer"]
declareHistogram libp2p_yamux_send_queue,
"message send queue length (in byte)",
buckets = [0.0, 100.0, 250.0, 1000.0, 2000.0, 3200.0, 6400.0, 25600.0, 256000.0]
declareHistogram libp2p_yamux_recv_queue,
"message recv queue length (in byte)",
buckets = [0.0, 100.0, 250.0, 1000.0, 2000.0, 3200.0, 6400.0, 25600.0, 256000.0]
declareGauge(libp2p_yamux_channels, "yamux channels", labels = ["initiator", "peer"])
declareHistogram libp2p_yamux_send_queue, "message send queue length (in byte)",
buckets = [0.0, 100.0, 250.0, 1000.0, 2000.0, 1600.0, 6400.0, 25600.0, 256000.0]
declareHistogram libp2p_yamux_recv_queue, "message recv queue length (in byte)",
buckets = [0.0, 100.0, 250.0, 1000.0, 2000.0, 1600.0, 6400.0, 25600.0, 256000.0]
type
YamuxError* = object of MuxerError
YamuxError* = object of CatchableError
MsgType = enum
Data = 0x0
@@ -50,9 +51,9 @@ type
Rst
GoAwayStatus = enum
NormalTermination = 0x0
ProtocolError = 0x1
InternalError = 0x2
NormalTermination = 0x0,
ProtocolError = 0x1,
InternalError = 0x2,
YamuxHeader = object
version: uint8
@@ -61,92 +62,90 @@ type
streamId: uint32
length: uint32
proc readHeader(
conn: LPStream
): Future[YamuxHeader] {.async: (raises: [CancelledError, LPStreamError, MuxerError]).} =
proc readHeader(conn: LPStream): Future[YamuxHeader] {.async, gcsafe.} =
var buffer: array[12, byte]
await conn.readExactly(addr buffer[0], 12)
result.version = buffer[0]
let flags = fromBytesBE(uint16, buffer[2 .. 3])
if not result.msgType.checkedEnumAssign(buffer[1]) or flags notin 0'u16 .. 15'u16:
raise newException(YamuxError, "Wrong header")
let flags = fromBytesBE(uint16, buffer[2..3])
if not result.msgType.checkedEnumAssign(buffer[1]) or flags notin 0'u16..15'u16:
raise newException(YamuxError, "Wrong header")
result.flags = cast[set[MsgFlags]](flags)
result.streamId = fromBytesBE(uint32, buffer[4 .. 7])
result.length = fromBytesBE(uint32, buffer[8 .. 11])
result.streamId = fromBytesBE(uint32, buffer[4..7])
result.length = fromBytesBE(uint32, buffer[8..11])
return result
proc `$`(header: YamuxHeader): string =
"{" & $header.msgType & ", " & "{" &
header.flags.foldl(
if a != "":
a & ", " & $b
else:
$b
,
"",
) & "}, " & "streamId: " & $header.streamId & ", " & "length: " & $header.length &
"}"
result = "{" & $header.msgType & ", "
result &= "{" & header.flags.foldl(if a != "": a & ", " & $b else: $b, "") & "}, "
result &= "streamId: " & $header.streamId & ", "
result &= "length: " & $header.length & "}"
proc encode(header: YamuxHeader): array[12, byte] =
result[0] = header.version
result[1] = uint8(header.msgType)
result[2 .. 3] = toBytesBE(uint16(cast[uint8](header.flags)))
# workaround https://github.com/nim-lang/Nim/issues/21789
result[4 .. 7] = toBytesBE(header.streamId)
result[8 .. 11] = toBytesBE(header.length)
result[2..3] = toBytesBE(cast[uint16](header.flags))
result[4..7] = toBytesBE(header.streamId)
result[8..11] = toBytesBE(header.length)
proc write(
conn: LPStream, header: YamuxHeader
): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true).} =
proc write(conn: LPStream, header: YamuxHeader): Future[void] {.gcsafe.} =
trace "write directly on stream", h = $header
var buffer = header.encode()
conn.write(@buffer)
return conn.write(@buffer)
proc ping(T: type[YamuxHeader], flag: MsgFlags, pingData: uint32): T =
T(version: YamuxVersion, msgType: MsgType.Ping, flags: {flag}, length: pingData)
T(
version: YamuxVersion,
msgType: MsgType.Ping,
flags: {flag},
length: pingData
)
proc goAway(T: type[YamuxHeader], status: GoAwayStatus): T =
T(version: YamuxVersion, msgType: MsgType.GoAway, length: uint32(status))
T(
version: YamuxVersion,
msgType: MsgType.GoAway,
length: uint32(status)
)
proc data(
T: type[YamuxHeader],
streamId: uint32,
length: uint32 = 0,
flags: set[MsgFlags] = {},
): T =
T: type[YamuxHeader],
streamId: uint32,
length: uint32 = 0,
flags: set[MsgFlags] = {},
): T =
T(
version: YamuxVersion,
msgType: MsgType.Data,
length: length,
flags: flags,
streamId: streamId,
streamId: streamId
)
proc windowUpdate(
T: type[YamuxHeader], streamId: uint32, delta: uint32, flags: set[MsgFlags] = {}
): T =
T: type[YamuxHeader],
streamId: uint32,
delta: uint32,
flags: set[MsgFlags] = {},
): T =
T(
version: YamuxVersion,
msgType: MsgType.WindowUpdate,
length: delta,
flags: flags,
streamId: streamId,
streamId: streamId
)
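Note: every constructor above serializes to the same fixed 12-byte header: version (1 byte), message type (1), flags (2, big-endian), stream id (4, big-endian), length (4, big-endian). Round-trip sketch; these helpers are module-private, so this is illustrative only:

    let h = YamuxHeader.windowUpdate(3, 1024, {Syn})
    let raw = h.encode()                                  # array[12, byte]
    doAssert fromBytesBE(uint32, raw[4 .. 7]) == 3        # stream id
    doAssert fromBytesBE(uint32, raw[8 .. 11]) == 1024    # length / delta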
type
ToSend =
tuple[
data: seq[byte],
sent: int,
fut: Future[void].Raising([CancelledError, LPStreamError]),
]
ToSend = tuple
data: seq[byte]
sent: int
fut: Future[void]
YamuxChannel* = ref object of Connection
id: uint32
recvWindow: int
sendWindow: int
maxRecvWindow: int
maxSendQueueSize: int
conn: Connection
isSrc: bool
opened: bool
@@ -155,73 +154,49 @@ type
recvQueue: seq[byte]
isReset: bool
remoteReset: bool
closedRemotely: AsyncEvent
closedRemotely: Future[void]
closedLocally: bool
receivedData: AsyncEvent
returnedEof: bool
proc `$`(channel: YamuxChannel): string =
result = if channel.conn.dir == Out: "=> " else: "<= "
result &= $channel.id
var s: seq[string] = @[]
if channel.closedRemotely.isSet():
if channel.closedRemotely.done():
s.add("ClosedRemotely")
if channel.closedLocally:
s.add("ClosedLocally")
if channel.isReset:
s.add("Reset")
if s.len > 0:
result &=
" {" &
s.foldl(
if a != "":
a & ", " & b
else:
b
,
"",
) & "}"
result &= " {" & s.foldl(if a != "": a & ", " & b else: b, "") & "}"
proc lengthSendQueue(channel: YamuxChannel): int =
## Returns the length of what remains to be sent
##
channel.sendQueue.foldl(a + b.data.len - b.sent, 0)
proc sendQueueBytes(channel: YamuxChannel, limit: bool = false): int =
for (elem, sent, _) in channel.sendQueue:
result.inc(min(elem.len - sent, if limit: channel.maxRecvWindow div 3 else: elem.len - sent))
proc lengthSendQueueWithLimit(channel: YamuxChannel): int =
## Returns the length of what remains to be sent, but limits the size of big messages.
##
# For leniency, limit big message size to a third of maxSendQueueSize.
# This value is arbitrary; it's not in the spec, and it permits storing up
# to 3 big messages if the peer is stalling.
channel.sendQueue.foldl(
a + min(b.data.len - b.sent, channel.maxSendQueueSize div 3), 0
)
method queuedSendBytes*(channel: YamuxChannel): int = channel.sendQueueBytes()
proc actuallyClose(channel: YamuxChannel) {.async: (raises: []).} =
proc actuallyClose(channel: YamuxChannel) {.async.} =
if channel.closedLocally and channel.sendQueue.len == 0 and
channel.closedRemotely.isSet():
channel.closedRemotely.done():
await procCall Connection(channel).closeImpl()
proc remoteClosed(channel: YamuxChannel) {.async: (raises: []).} =
if not channel.closedRemotely.isSet():
channel.closedRemotely.fire()
proc remoteClosed(channel: YamuxChannel) {.async.} =
if not channel.closedRemotely.done():
channel.closedRemotely.complete()
await channel.actuallyClose()
method closeImpl*(channel: YamuxChannel) {.async: (raises: []).} =
method closeImpl*(channel: YamuxChannel) {.async, gcsafe.} =
if not channel.closedLocally:
trace "Closing yamux channel locally", streamId = channel.id, conn = channel.conn
channel.closedLocally = true
if not channel.isReset and channel.sendQueue.len == 0:
try:
await channel.conn.write(YamuxHeader.data(channel.id, 0, {Fin}))
except CancelledError, LPStreamError:
discard
if channel.isReset == false and channel.sendQueue.len == 0:
await channel.conn.write(YamuxHeader.data(channel.id, 0, {Fin}))
await channel.actuallyClose()
proc reset(channel: YamuxChannel, isLocal: bool = false) {.async: (raises: []).} =
# If we reset locally, we want to flush up to a maximum of recvWindow
# bytes. It's because the peer we're connected to can send us data before
# it receives the reset.
proc reset(channel: YamuxChannel, isLocal: bool = false) {.async.} =
if channel.isReset:
return
trace "Reset channel"
@@ -233,76 +208,66 @@ proc reset(channel: YamuxChannel, isLocal: bool = false) {.async: (raises: []).}
channel.recvQueue = @[]
channel.sendWindow = 0
if not channel.closedLocally:
if isLocal and not channel.isSending:
try:
await channel.conn.write(YamuxHeader.data(channel.id, 0, {Rst}))
except CancelledError, LPStreamError:
discard
if isLocal:
try: await channel.conn.write(YamuxHeader.data(channel.id, 0, {Rst}))
except LPStreamEOFError as exc: discard
except LPStreamClosedError as exc: discard
await channel.close()
if not channel.closedRemotely.isSet():
if not channel.closedRemotely.done():
await channel.remoteClosed()
channel.receivedData.fire()
if not isLocal:
# If the reset is remote, there's no reason to flush anything.
# If we reset locally, we want to flush up to a maximum of recvWindow
# bytes. We use the recvWindow in the proc cleanupChann.
channel.recvWindow = 0
proc updateRecvWindow(
channel: YamuxChannel
) {.async: (raises: [CancelledError, LPStreamError]).} =
## Send to the peer a window update when the recvWindow is empty enough
##
# In order to avoid spamming a window update every time a byte is read,
# we send one every time half of the maxRecvWindow has been read.
proc updateRecvWindow(channel: YamuxChannel) {.async.} =
let inWindow = channel.recvWindow + channel.recvQueue.len
if inWindow > channel.maxRecvWindow div 2:
return
let delta = channel.maxRecvWindow - inWindow
channel.recvWindow.inc(delta)
await channel.conn.write(YamuxHeader.windowUpdate(channel.id, delta.uint32))
await channel.conn.write(YamuxHeader.windowUpdate(
channel.id,
delta.uint32
))
trace "increasing the recvWindow", delta
method readOnce*(
channel: YamuxChannel, pbytes: pointer, nbytes: int
): Future[int] {.async: (raises: [CancelledError, LPStreamError]).} =
## Read from a yamux channel
channel: YamuxChannel,
pbytes: pointer,
nbytes: int):
Future[int] {.async.} =
if channel.isReset:
raise
if channel.remoteReset:
raise if channel.remoteReset:
newLPStreamResetError()
elif channel.closedLocally:
newLPStreamClosedError()
else:
newLPStreamConnDownError()
if channel.isEof:
if channel.returnedEof:
raise newLPStreamRemoteClosedError()
if channel.recvQueue.len == 0:
channel.receivedData.clear()
try: # https://github.com/status-im/nim-chronos/issues/516
discard await race(channel.closedRemotely.wait(), channel.receivedData.wait())
except ValueError:
raiseAssert("Futures list is not empty")
if channel.closedRemotely.isSet() and channel.recvQueue.len == 0:
channel.isEof = true
return
0 # we return 0 to indicate that the channel is closed for reading from now on
await channel.closedRemotely or channel.receivedData.wait()
if channel.closedRemotely.done() and channel.recvQueue.len == 0:
channel.returnedEof = true
return 0
let toRead = min(channel.recvQueue.len, nbytes)
var p = cast[ptr UncheckedArray[byte]](pbytes)
toOpenArray(p, 0, nbytes - 1)[0 ..< toRead] =
channel.recvQueue.toOpenArray(0, toRead - 1)
channel.recvQueue = channel.recvQueue[toRead ..^ 1]
toOpenArray(p, 0, nbytes - 1)[0..<toRead] = channel.recvQueue.toOpenArray(0, toRead - 1)
channel.recvQueue = channel.recvQueue[toRead..^1]
# We made some room in the recv buffer let the peer know
await channel.updateRecvWindow()
channel.activity = true
return toRead
proc gotDataFromRemote(
channel: YamuxChannel, b: seq[byte]
) {.async: (raises: [CancelledError, LPStreamError]).} =
proc gotDataFromRemote(channel: YamuxChannel, b: seq[byte]) {.async.} =
channel.recvWindow -= b.len
channel.recvQueue = channel.recvQueue.concat(b)
channel.receivedData.fire()
@@ -313,28 +278,26 @@ proc gotDataFromRemote(
proc setMaxRecvWindow*(channel: YamuxChannel, maxRecvWindow: int) =
channel.maxRecvWindow = maxRecvWindow
proc trySend(
channel: YamuxChannel
) {.async: (raises: [CancelledError, LPStreamError]).} =
proc trySend(channel: YamuxChannel) {.async.} =
if channel.isSending:
return
channel.isSending = true
defer:
channel.isSending = false
defer: channel.isSending = false
while channel.sendQueue.len != 0:
channel.sendQueue.keepItIf(not (it.fut.cancelled() and it.sent == 0))
if channel.sendWindow == 0:
trace "trying to send while the sendWindow is empty"
if channel.lengthSendQueueWithLimit() > channel.maxSendQueueSize:
trace "channel send queue too big, resetting",
maxSendQueueSize = channel.maxSendQueueSize,
currentQueueSize = channel.lengthSendQueueWithLimit()
await channel.reset(isLocal = true)
trace "send window empty"
if channel.sendQueueBytes(true) > channel.maxRecvWindow:
debug "channel send queue too big, resetting", maxSendWindow=channel.maxRecvWindow,
currentQueueSize = channel.sendQueueBytes(true)
try:
await channel.reset(true)
except CatchableError as exc:
debug "failed to reset", msg=exc.msg
break
let
bytesAvailable = channel.lengthSendQueue()
bytesAvailable = channel.sendQueueBytes()
toSend = min(channel.sendWindow, bytesAvailable)
var
sendBuffer = newSeqUninitialized[byte](toSend + 12)
@@ -345,37 +308,24 @@ proc trySend(
trace "last buffer we'll sent on this channel", toSend, bytesAvailable
header.flags.incl({Fin})
sendBuffer[0 ..< 12] = header.encode()
sendBuffer[0..<12] = header.encode()
var futures: seq[Future[void].Raising([CancelledError, LPStreamError])]
var futures: seq[Future[void]]
while inBuffer < toSend:
# concatenate the different message we try to send into one buffer
let (data, sent, fut) = channel.sendQueue[0]
let bufferToSend = min(data.len - sent, toSend - inBuffer)
sendBuffer.toOpenArray(12, 12 + toSend - 1)[
inBuffer ..< (inBuffer + bufferToSend)
] = channel.sendQueue[0].data.toOpenArray(sent, sent + bufferToSend - 1)
sendBuffer.toOpenArray(12, 12 + toSend - 1)[inBuffer..<(inBuffer+bufferToSend)] =
channel.sendQueue[0].data.toOpenArray(sent, sent + bufferToSend - 1)
channel.sendQueue[0].sent.inc(bufferToSend)
if channel.sendQueue[0].sent >= data.len:
# if every byte of the message is in the buffer, add the write future to the
# sequence of futures to be completed (or failed) when the buffer is sent
futures.add(fut)
channel.sendQueue.delete(0)
inBuffer.inc(bufferToSend)
trace "try to send the buffer", h = $header
trace "build send buffer", h = $header, msg=string.fromBytes(sendBuffer[12..^1])
channel.sendWindow.dec(toSend)
try:
await channel.conn.write(sendBuffer)
except CancelledError:
trace "cancelled sending the buffer"
for fut in futures.items():
fut.cancelSoon()
await channel.reset()
break
except LPStreamError as exc:
trace "failed to send the buffer"
try: await channel.conn.write(sendBuffer)
except CatchableError as exc:
let connDown = newLPStreamConnDownError(exc)
for fut in futures.items():
fut.fail(connDown)
@@ -385,11 +335,7 @@ proc trySend(
fut.complete()
channel.activity = true
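Note: the queue-size guard in `trySend` counts each queued message at no more than a third of the configured limit (`maxSendQueueSize` in the newer code, `maxRecvWindow` in the older), so a stalled peer can pin at most roughly three maximum-size messages before the channel is reset. Illustrative accounting:

    let maxSendQueueSize = 256000
    let cap = maxSendQueueSize div 3      # per-message contribution cap (85333)
    var limited = 0
    for remaining in [300000, 5000]:      # bytes left per queued message
      limited += min(remaining, cap)
    doAssert limited == cap + 5000        # the big message counts as one third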
method write*(
channel: YamuxChannel, msg: seq[byte]
): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true).} =
## Write to yamux channel
##
method write*(channel: YamuxChannel, msg: seq[byte]): Future[void] =
result = newFuture[void]("Yamux Send")
if channel.remoteReset:
result.fail(newLPStreamResetError())
@@ -402,100 +348,65 @@ method write*(
return result
channel.sendQueue.add((msg, 0, result))
when defined(libp2p_yamux_metrics):
libp2p_yamux_send_queue.observe(channel.lengthSendQueue().int64)
libp2p_yamux_recv_queue.observe(channel.sendQueueBytes().int64)
asyncSpawn channel.trySend()
proc open(channel: YamuxChannel) {.async: (raises: [CancelledError, LPStreamError]).} =
## Open a yamux channel by sending a window update with Syn or Ack flag
##
proc open*(channel: YamuxChannel) {.async, gcsafe.} =
if channel.opened:
trace "Try to open channel twice"
return
channel.opened = true
await channel.conn.write(
YamuxHeader.windowUpdate(
channel.id,
uint32(max(channel.maxRecvWindow - YamuxDefaultWindowSize, 0)),
{if channel.isSrc: Syn else: Ack},
)
)
await channel.conn.write(YamuxHeader.data(channel.id, 0, {if channel.isSrc: Syn else: Ack}))
method getWrapped*(channel: YamuxChannel): Connection =
channel.conn
type Yamux* = ref object of Muxer
channels: Table[uint32, YamuxChannel]
flushed: Table[uint32, int]
currentId: uint32
isClosed: bool
maxChannCount: int
windowSize: int
maxSendQueueSize: int
inTimeout: Duration
outTimeout: Duration
type
Yamux* = ref object of Muxer
channels: Table[uint32, YamuxChannel]
flushed: Table[uint32, int]
currentId: uint32
isClosed: bool
maxChannCount: int
proc lenBySrc(m: Yamux, isSrc: bool): int =
for v in m.channels.values():
if v.isSrc == isSrc:
result += 1
if v.isSrc == isSrc: result += 1
proc cleanupChannel(m: Yamux, channel: YamuxChannel) {.async: (raises: []).} =
try:
await channel.join()
except CancelledError:
discard
proc cleanupChann(m: Yamux, channel: YamuxChannel) {.async.} =
await channel.join()
m.channels.del(channel.id)
when defined(libp2p_yamux_metrics):
libp2p_yamux_channels.set(
m.lenBySrc(channel.isSrc).int64, [$channel.isSrc, $channel.peerId]
)
libp2p_yamux_channels.set(m.lenBySrc(channel.isSrc).int64, [$channel.isSrc, $channel.peerId])
if channel.isReset and channel.recvWindow > 0:
m.flushed[channel.id] = channel.recvWindow
proc createStream(
m: Yamux, id: uint32, isSrc: bool, recvWindow: int, maxSendQueueSize: int
): YamuxChannel =
# During initialization, recvWindow can be larger than maxRecvWindow.
# This is because the peer we're connected to will always assume
# that the initial recvWindow is 256k.
# To resolve this discrepancy, no window update is sent until
# recvWindow drops below maxRecvWindow.
var stream = YamuxChannel(
proc createStream(m: Yamux, id: uint32, isSrc: bool): YamuxChannel =
result = YamuxChannel(
id: id,
maxRecvWindow: recvWindow,
recvWindow:
if recvWindow > YamuxDefaultWindowSize: recvWindow else: YamuxDefaultWindowSize,
sendWindow: YamuxDefaultWindowSize,
maxSendQueueSize: maxSendQueueSize,
maxRecvWindow: DefaultWindowSize,
recvWindow: DefaultWindowSize,
sendWindow: DefaultWindowSize,
isSrc: isSrc,
conn: m.connection,
receivedData: newAsyncEvent(),
closedRemotely: newAsyncEvent(),
closedRemotely: newFuture[void]()
)
stream.objName = "YamuxStream"
if isSrc:
stream.dir = Direction.Out
stream.timeout = m.outTimeout
else:
stream.dir = Direction.In
stream.timeout = m.inTimeout
stream.timeoutHandler = proc(): Future[void] {.async: (raises: [], raw: true).} =
result.objName = "YamuxStream"
result.dir = if isSrc: Direction.Out else: Direction.In
result.timeoutHandler = proc(): Future[void] {.gcsafe.} =
trace "Idle timeout expired, resetting YamuxChannel"
stream.reset(isLocal = true)
stream.initStream()
stream.peerId = m.connection.peerId
stream.observedAddr = m.connection.observedAddr
stream.transportDir = m.connection.transportDir
result.reset()
result.initStream()
result.peerId = m.connection.peerId
result.observedAddr = m.connection.observedAddr
result.transportDir = m.connection.transportDir
when defined(libp2p_agents_metrics):
stream.shortAgent = m.connection.shortAgent
m.channels[id] = stream
asyncSpawn m.cleanupChannel(stream)
trace "created channel", id, pid = m.connection.peerId
result.shortAgent = m.connection.shortAgent
m.channels[id] = result
asyncSpawn m.cleanupChann(result)
trace "created channel", id, pid=m.connection.peerId
when defined(libp2p_yamux_metrics):
libp2p_yamux_channels.set(m.lenBySrc(isSrc).int64, [$isSrc, $stream.peerId])
return stream
libp2p_yamux_channels.set(m.lenBySrc(isSrc).int64, [$isSrc, $result.peerId])
method close*(m: Yamux) {.async: (raises: []).} =
method close*(m: Yamux) {.async.} =
if m.isClosed == true:
trace "Already closed"
return
@@ -504,86 +415,69 @@ method close*(m: Yamux) {.async: (raises: []).} =
trace "Closing yamux"
let channels = toSeq(m.channels.values())
for channel in channels:
await channel.reset(isLocal = true)
try:
await m.connection.write(YamuxHeader.goAway(NormalTermination))
except CancelledError as exc:
trace "cancelled sending goAway", msg = exc.msg
except LPStreamError as exc:
trace "failed to send goAway", msg = exc.msg
await channel.reset(true)
try: await m.connection.write(YamuxHeader.goAway(NormalTermination))
except CatchableError as exc: trace "failed to send goAway", msg=exc.msg
await m.connection.close()
trace "Closed yamux"
proc handleStream(m: Yamux, channel: YamuxChannel) {.async: (raises: []).} =
## Call the muxer stream handler for this channel
proc handleStream(m: Yamux, channel: YamuxChannel) {.async.} =
## call the muxer stream handler for this channel
##
await m.streamHandler(channel)
trace "finished handling stream"
doAssert(channel.isClosed, "connection not closed by handler!")
try:
await m.streamHandler(channel)
trace "finished handling stream"
doAssert(channel.isClosed, "connection not closed by handler!")
except CatchableError as exc:
trace "Exception in yamux stream handler", msg = exc.msg
await channel.reset()
method handle*(m: Yamux) {.async: (raises: []).} =
trace "Starting yamux handler", pid = m.connection.peerId
method handle*(m: Yamux) {.async, gcsafe.} =
trace "Starting yamux handler", pid=m.connection.peerId
try:
while not m.connection.atEof:
trace "waiting for header"
let header = await m.connection.readHeader()
trace "got message", h = $header
case header.msgType
case header.msgType:
of Ping:
if MsgFlags.Syn in header.flags:
await m.connection.write(YamuxHeader.ping(MsgFlags.Ack, header.length))
of GoAway:
var status: GoAwayStatus
if status.checkedEnumAssign(header.length):
trace "Received go away", status
else:
trace "Received unexpected error go away"
if status.checkedEnumAssign(header.length): trace "Received go away", status
else: trace "Received unexpected error go away"
break
of Data, WindowUpdate:
if MsgFlags.Syn in header.flags:
if header.streamId in m.channels:
debug "Trying to create an existing channel, skipping", id = header.streamId
debug "Trying to create an existing channel, skipping", id=header.streamId
else:
if header.streamId in m.flushed:
m.flushed.del(header.streamId)
if header.streamId mod 2 == m.currentId mod 2:
debug "Peer used our reserved stream id, skipping",
id = header.streamId,
currentId = m.currentId,
peerId = m.connection.peerId
raise newException(YamuxError, "Peer used our reserved stream id")
let newStream =
m.createStream(header.streamId, false, m.windowSize, m.maxSendQueueSize)
let newStream = m.createStream(header.streamId, false)
if m.channels.len >= m.maxChannCount:
await newStream.reset()
continue
await newStream.open()
asyncSpawn m.handleStream(newStream)
elif header.streamId notin m.channels:
# Flush the data
m.flushed.withValue(header.streamId, flushed):
if header.msgType == Data:
flushed[].dec(int(header.length))
if flushed[] < 0:
raise
newException(YamuxError, "Peer exhausted the recvWindow after reset")
if header.length > 0:
var buffer = newSeqUninitialized[byte](header.length)
await m.connection.readExactly(addr buffer[0], int(header.length))
do:
if header.streamId notin m.flushed:
raise newException(YamuxError, "Unknown stream ID: " & $header.streamId)
elif header.msgType == Data:
# Flush the data
m.flushed[header.streamId].dec(int(header.length))
if m.flushed[header.streamId] < 0:
raise newException(YamuxError, "Peer exhausted the recvWindow after reset")
if header.length > 0:
var buffer = newSeqUninitialized[byte](header.length)
await m.connection.readExactly(addr buffer[0], int(header.length))
continue
let channel =
try:
m.channels[header.streamId]
except KeyError:
raise newException(
YamuxError,
"Stream was cleaned up before handling data: " & $header.streamId,
)
let channel = m.channels[header.streamId]
if header.msgType == WindowUpdate:
channel.sendWindow += int(header.length)
@@ -596,7 +490,7 @@ method handle*(m: Yamux) {.async: (raises: []).} =
if header.length > 0:
var buffer = newSeqUninitialized[byte](header.length)
await m.connection.readExactly(addr buffer[0], int(header.length))
trace "Msg Rcv", msg = shortLog(buffer)
trace "Msg Rcv", msg=string.fromBytes(buffer)
await channel.gotDataFromRemote(buffer)
if MsgFlags.Fin in header.flags:
@@ -605,58 +499,31 @@ method handle*(m: Yamux) {.async: (raises: []).} =
if MsgFlags.Rst in header.flags:
trace "remote reset channel"
await channel.reset()
except CancelledError as exc:
debug "Unexpected cancellation in yamux handler", msg = exc.msg
except LPStreamEOFError as exc:
trace "Stream EOF", msg = exc.msg
except LPStreamError as exc:
debug "Unexpected stream exception in yamux read loop", msg = exc.msg
except YamuxError as exc:
trace "Closing yamux connection", error = exc.msg
try:
await m.connection.write(YamuxHeader.goAway(ProtocolError))
except CancelledError, LPStreamError:
discard
except MuxerError as exc:
debug "Unexpected muxer exception in yamux read loop", msg = exc.msg
try:
await m.connection.write(YamuxHeader.goAway(ProtocolError))
except CancelledError, LPStreamError:
discard
trace "Closing yamux connection", error=exc.msg
await m.connection.write(YamuxHeader.goAway(ProtocolError))
finally:
await m.close()
trace "Stopped yamux handler"
method getStreams*(m: Yamux): seq[Connection] =
for c in m.channels.values:
result.add(c)
method newStream*(
m: Yamux, name: string = "", lazy: bool = false
): Future[Connection] {.async: (raises: [CancelledError, LPStreamError, MuxerError]).} =
m: Yamux,
name: string = "",
lazy: bool = false): Future[Connection] {.async, gcsafe.} =
if m.channels.len > m.maxChannCount - 1:
raise newException(TooManyChannels, "max allowed channel count exceeded")
let stream = m.createStream(m.currentId, true, m.windowSize, m.maxSendQueueSize)
let stream = m.createStream(m.currentId, true)
m.currentId += 2
if not lazy:
await stream.open()
return stream
proc new*(
T: type[Yamux],
conn: Connection,
maxChannCount: int = MaxChannelCount,
windowSize: int = YamuxDefaultWindowSize,
maxSendQueueSize: int = MaxSendQueueSize,
inTimeout: Duration = 5.minutes,
outTimeout: Duration = 5.minutes,
): T =
proc new*(T: type[Yamux], conn: Connection, maxChannCount: int = MaxChannelCount): T =
T(
connection: conn,
currentId: if conn.dir == Out: 1 else: 2,
maxChannCount: maxChannCount,
windowSize: windowSize,
maxSendQueueSize: maxSendQueueSize,
inTimeout: inTimeout,
outTimeout: outTimeout,
maxChannCount: maxChannCount
)
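Note: the newer constructor exposes window size, send-queue limit and per-direction timeouts, while the older one only took `maxChannCount`. A hedged usage sketch against the newer signature; `conn` is assumed to be an established `Connection` and the values are illustrative:

    proc demo(conn: Connection) {.async.} =
      let yamux = Yamux.new(
        conn,
        maxChannCount = 200,
        windowSize = 512_000,          # above the 256k default
        maxSendQueueSize = 512_000)
      asyncSpawn yamux.handle()
      let stream = await yamux.newStream()
      await stream.close()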

View File

@@ -7,23 +7,25 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import
std/[streams, strutils, sets, sequtils],
chronos,
chronicles,
stew/byteutils,
dnsclientpkg/[protocol, types],
../utility
chronos, chronicles, stew/byteutils,
dnsclientpkg/[protocol, types]
import nameresolver
import
nameresolver
logScope:
topics = "libp2p dnsresolver"
type DnsResolver* = ref object of NameResolver
nameServers*: seq[TransportAddress]
type
DnsResolver* = ref object of NameResolver
nameServers*: seq[TransportAddress]
proc questionToBuf(address: string, kind: QKind): seq[byte] =
try:
@@ -45,8 +47,10 @@ proc questionToBuf(address: string, kind: QKind): seq[byte] =
return newSeq[byte](0)
proc getDnsResponse(
dnsServer: TransportAddress, address: string, kind: QKind
): Future[Response] {.async.} =
dnsServer: TransportAddress,
address: string,
kind: QKind): Future[Response] {.async.} =
var sendBuf = questionToBuf(address, kind)
if sendBuf.len == 0:
@@ -54,10 +58,9 @@ proc getDnsResponse(
let receivedDataFuture = newFuture[void]()
proc datagramDataReceived(
transp: DatagramTransport, raddr: TransportAddress
): Future[void] {.async, closure.} =
receivedDataFuture.complete()
proc datagramDataReceived(transp: DatagramTransport,
raddr: TransportAddress): Future[void] {.async, closure.} =
receivedDataFuture.complete()
let sock =
if dnsServer.family == AddressFamily.IPv6:
@@ -77,14 +80,18 @@ proc getDnsResponse(
# parseResponse can have a raises: [Exception, ..] because of
# https://github.com/nim-lang/Nim/commit/035134de429b5d99c5607c5fae912762bebb6008
# it can't actually raise though
return exceptionToAssert:
parseResponse(string.fromBytes(rawResponse))
return parseResponse(string.fromBytes(rawResponse))
except CatchableError as exc: raise exc
except Exception as exc: raiseAssert exc.msg
finally:
await sock.closeWait()
method resolveIp*(
self: DnsResolver, address: string, port: Port, domain: Domain = Domain.AF_UNSPEC
): Future[seq[TransportAddress]] {.async.} =
self: DnsResolver,
address: string,
port: Port,
domain: Domain = Domain.AF_UNSPEC): Future[seq[TransportAddress]] {.async.} =
trace "Resolving IP using DNS", address, servers = self.nameServers.mapIt($it), domain
for _ in 0 ..< self.nameServers.len:
let server = self.nameServers[0]
@@ -110,14 +117,18 @@ method resolveIp*(
# toString can have a raises: [Exception, ..] because of
# https://github.com/nim-lang/Nim/commit/035134de429b5d99c5607c5fae912762bebb6008
# it can't actually raise though
resolvedAddresses.incl(exceptionToAssert(answer.toString()))
resolvedAddresses.incl(
try: answer.toString()
except CatchableError as exc: raise exc
except Exception as exc: raiseAssert exc.msg
)
except CancelledError as e:
raise e
except ValueError as e:
info "Invalid DNS query", address, error = e.msg
info "Invalid DNS query", address, error=e.msg
return @[]
except CatchableError as e:
info "Failed to query DNS", address, error = e.msg
info "Failed to query DNS", address, error=e.msg
resolveFailed = true
break
@@ -132,29 +143,34 @@ method resolveIp*(
debug "Failed to resolve address, returning empty set"
return @[]
method resolveTxt*(self: DnsResolver, address: string): Future[seq[string]] {.async.} =
method resolveTxt*(
self: DnsResolver,
address: string): Future[seq[string]] {.async.} =
trace "Resolving TXT using DNS", address, servers = self.nameServers.mapIt($it)
for _ in 0 ..< self.nameServers.len:
let server = self.nameServers[0]
try:
# toString can have a raises: [Exception, ..] because of
# https://github.com/nim-lang/Nim/commit/035134de429b5d99c5607c5fae912762bebb6008
# it can't actually raise though
let response = await getDnsResponse(server, address, TXT)
return exceptionToAssert:
trace "Got TXT response",
server = $server, answer = response.answers.mapIt(it.toString())
response.answers.mapIt(it.toString())
trace "Got TXT response", server = $server, answer=response.answers.mapIt(it.toString())
return response.answers.mapIt(it.toString())
except CancelledError as e:
raise e
except CatchableError as e:
info "Failed to query DNS", address, error = e.msg
info "Failed to query DNS", address, error=e.msg
self.nameServers.add(self.nameServers[0])
self.nameServers.delete(0)
continue
except Exception as e:
# toString can have a raises: [Exception, ..] because of
# https://github.com/nim-lang/Nim/commit/035134de429b5d99c5607c5fae912762bebb6008
# it can't actually raise though
raiseAssert e.msg
debug "Failed to resolve TXT, returning empty set"
return @[]
proc new*(T: typedesc[DnsResolver], nameServers: seq[TransportAddress]): T =
proc new*(
T: typedesc[DnsResolver],
nameServers: seq[TransportAddress]): T =
T(nameServers: nameServers)
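Note: on failure the resolver rotates the failing nameserver to the back of the list and tries the next one. Usage sketch; the server address is illustrative and the import path matches this repository's layout:

    import chronos
    import libp2p/nameresolving/dnsresolver

    proc demo() {.async.} =
      let resolver = DnsResolver.new(@[initTAddress("1.1.1.1:53")])
      let addrs = await resolver.resolveIp("example.com", Port(443))
      echo addrs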

View File

@@ -7,9 +7,14 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import std/tables, chronos, chronicles
import
std/[streams, strutils, tables],
chronos, chronicles
import nameresolver
@@ -24,8 +29,10 @@ type MockResolver* = ref object of NameResolver
ipResponses*: Table[(string, bool), seq[string]]
method resolveIp*(
self: MockResolver, address: string, port: Port, domain: Domain = Domain.AF_UNSPEC
): Future[seq[TransportAddress]] {.async.} =
self: MockResolver,
address: string,
port: Port,
domain: Domain = Domain.AF_UNSPEC): Future[seq[TransportAddress]] {.async.} =
if domain == Domain.AF_INET or domain == Domain.AF_UNSPEC:
for resp in self.ipResponses.getOrDefault((address, false)):
result.add(initTAddress(resp, port))
@@ -34,8 +41,9 @@ method resolveIp*(
for resp in self.ipResponses.getOrDefault((address, true)):
result.add(initTAddress(resp, port))
method resolveTxt*(self: MockResolver, address: string): Future[seq[string]] {.async.} =
method resolveTxt*(
self: MockResolver,
address: string): Future[seq[string]] {.async.} =
return self.txtResponses.getOrDefault(address)
proc new*(T: typedesc[MockResolver]): T =
T()
proc new*(T: typedesc[MockResolver]): T = T()
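Hypothetical test setup for the mock, relying only on the fields shown above (note the (address, isIpv6) tuple keys of ipResponses); all names and addresses here are illustrative:

let resolver = MockResolver.new()
resolver.ipResponses[("example.com", false)] = @["1.2.3.4"]   # A records
resolver.ipResponses[("example.com", true)] = @["::1"]        # AAAA records
resolver.txtResponses["_dnsaddr.example.com"] = @["dnsaddr=/ip4/1.2.3.4/tcp/4001"]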

View File

@@ -7,28 +7,37 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import std/[sugar, sets, sequtils, strutils]
import chronos, chronicles, stew/endians2
import
chronos,
chronicles,
stew/[endians2, byteutils]
import ".."/[multiaddress, multicodec]
logScope:
topics = "libp2p nameresolver"
type NameResolver* = ref object of RootObj
type
NameResolver* = ref object of RootObj
method resolveTxt*(
self: NameResolver, address: string
): Future[seq[string]] {.async, base.} =
self: NameResolver,
address: string): Future[seq[string]] {.async, base.} =
## Get TXT record
##
doAssert(false, "Not implemented!")
method resolveIp*(
self: NameResolver, address: string, port: Port, domain: Domain = Domain.AF_UNSPEC
): Future[seq[TransportAddress]] {.async, base.} =
self: NameResolver,
address: string,
port: Port,
domain: Domain = Domain.AF_UNSPEC): Future[seq[TransportAddress]] {.async, base.} =
## Resolve the specified address
##
@@ -36,17 +45,17 @@ method resolveIp*(
proc getHostname*(ma: MultiAddress): string =
let
firstPart = ma[0].valueOr:
return ""
firstPart = ma[0].valueOr: return ""
fpSplitted = ($firstPart).split('/', 2)
if fpSplitted.len > 2:
fpSplitted[2]
else:
""
if fpSplitted.len > 2: fpSplitted[2]
else: ""
proc resolveOneAddress(
self: NameResolver, ma: MultiAddress, domain: Domain = Domain.AF_UNSPEC, prefix = ""
): Future[seq[MultiAddress]] {.async.} =
self: NameResolver,
ma: MultiAddress,
domain: Domain = Domain.AF_UNSPEC,
prefix = ""): Future[seq[MultiAddress]]
{.async, raises: [Defect, MaError, TransportAddressError].} =
#Resolve a single address
var pbuf: array[2, byte]
@@ -62,14 +71,15 @@ proc resolveOneAddress(
for address in resolvedAddresses:
var createdAddress = MultiAddress.init(address).tryGet()[0].tryGet()
for part in ma:
if DNS.match(part.tryGet()):
continue
if DNS.match(part.tryGet()): continue
createdAddress &= part.tryGet()
createdAddress
proc resolveDnsAddr*(
self: NameResolver, ma: MultiAddress, depth: int = 0
): Future[seq[MultiAddress]] {.async.} =
self: NameResolver,
ma: MultiAddress,
depth: int = 0): Future[seq[MultiAddress]] {.async.} =
if not DNSADDR.matchPartial(ma):
return @[ma]
@@ -86,12 +96,10 @@ proc resolveDnsAddr*(
var result: seq[MultiAddress]
for entry in txt:
if not entry.startsWith("dnsaddr="):
continue
let entryValue = MultiAddress.init(entry[8 ..^ 1]).tryGet()
if not entry.startsWith("dnsaddr="): continue
let entryValue = MultiAddress.init(entry[8..^1]).tryGet()
if entryValue.contains(multiCodec("p2p")).tryGet() and
ma.contains(multiCodec("p2p")).tryGet():
if entryValue.contains(multiCodec("p2p")).tryGet() and ma.contains(multiCodec("p2p")).tryGet():
if entryValue[multiCodec("p2p")] != ma[multiCodec("p2p")]:
continue
@@ -104,17 +112,17 @@ proc resolveDnsAddr*(
return @[]
return result
proc resolveMAddress*(
self: NameResolver, address: MultiAddress
): Future[seq[MultiAddress]] {.async.} =
self: NameResolver,
address: MultiAddress): Future[seq[MultiAddress]] {.async.} =
var res = initOrderedSet[MultiAddress]()
if not DNS.matchPartial(address):
res.incl(address)
else:
let code = address[0].tryGet().protoCode().tryGet()
let seq =
case code
let code = address[0].get().protoCode().get()
let seq = case code:
of multiCodec("dns"):
await self.resolveOneAddress(address)
of multiCodec("dns4"):
@@ -124,7 +132,7 @@ proc resolveMAddress*(
of multiCodec("dnsaddr"):
await self.resolveDnsAddr(address)
else:
assert false
doAssert false
@[address]
for ad in seq:
res.incl(ad)
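The dnsaddr scheme stores multiaddresses in TXT records as dnsaddr=<multiaddr> entries; entry[8 ..^ 1] strips the 8-character prefix, and entries carrying a /p2p/ part are only kept when they match the queried peer. A standalone sketch of the prefix handling:

import std/strutils

let entry = "dnsaddr=/ip4/147.75.83.83/tcp/4001"
if entry.startsWith("dnsaddr="):
  let value = entry[8 .. ^1]   # drop the 8-char "dnsaddr=" prefix
  assert value == "/ip4/147.75.83.83/tcp/4001"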

View File

@@ -1,92 +0,0 @@
# Nim-LibP2P
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
import std/[sequtils, tables, sugar]
import chronos
import multiaddress, multicodec
type
## Manages MultiAddresses observed by remote peers. It keeps track of the most observed IP and IP/Port.
ObservedAddrManager* = ref object of RootObj
observedIPsAndPorts: seq[MultiAddress]
maxSize: int
minCount: int
proc addObservation*(self: ObservedAddrManager, observedAddr: MultiAddress): bool =
## Adds a new observed MultiAddress. If the number of observations exceeds maxSize, the oldest one is removed.
if self.observedIPsAndPorts.len >= self.maxSize:
self.observedIPsAndPorts.del(0)
self.observedIPsAndPorts.add(observedAddr)
return true
proc getProtocol(
self: ObservedAddrManager, observations: seq[MultiAddress], multiCodec: MultiCodec
): Opt[MultiAddress] =
var countTable = toCountTable(observations)
countTable.sort()
var orderedPairs = toSeq(countTable.pairs)
for (ma, count) in orderedPairs:
let protoCode = (ma[0].flatMap(protoCode)).valueOr:
continue
if protoCode == multiCodec and count >= self.minCount:
return Opt.some(ma)
return Opt.none(MultiAddress)
proc getMostObservedProtocol(
self: ObservedAddrManager, multiCodec: MultiCodec
): Opt[MultiAddress] =
## Returns the most observed IP address or none if the number of observations is less than minCount.
let observedIPs = collect:
for observedIp in self.observedIPsAndPorts:
observedIp[0].valueOr:
continue
return self.getProtocol(observedIPs, multiCodec)
proc getMostObservedProtoAndPort(
self: ObservedAddrManager, multiCodec: MultiCodec
): Opt[MultiAddress] =
## Returns the most observed IP/Port address or none if the number of observations is less than minCount.
return self.getProtocol(self.observedIPsAndPorts, multiCodec)
proc getMostObservedProtosAndPorts*(self: ObservedAddrManager): seq[MultiAddress] =
## Returns the most observed IP4/Port and IP6/Port address or an empty seq if the number of observations
## is less than minCount.
var res: seq[MultiAddress]
self.getMostObservedProtoAndPort(multiCodec("ip4")).withValue(ip4):
res.add(ip4)
self.getMostObservedProtoAndPort(multiCodec("ip6")).withValue(ip6):
res.add(ip6)
return res
proc guessDialableAddr*(self: ObservedAddrManager, ma: MultiAddress): MultiAddress =
## Replaces the first proto value of each listen address with the most observed value matching its proto code.
## If the most observed value is not available, the original MultiAddress is returned.
let
maFirst = ma[0].valueOr:
return ma
maRest = ma[1 ..^ 1].valueOr:
return ma
maFirstProto = maFirst.protoCode().valueOr:
return ma
let observedIP = self.getMostObservedProtocol(maFirstProto).valueOr:
return ma
return concat(observedIP, maRest).valueOr:
ma
proc `$`*(self: ObservedAddrManager): string =
## Returns a string representation of the ObservedAddrManager.
return "IPs and Ports: " & $self.observedIPsAndPorts
proc new*(T: typedesc[ObservedAddrManager], maxSize = 10, minCount = 3): T =
## Creates a new ObservedAddrManager.
return
T(observedIPsAndPorts: newSeq[MultiAddress](), maxSize: maxSize, minCount: minCount)
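The manager only reports an address once it has been seen at least minCount times; getProtocol sorts a CountTable of observations and takes the most frequent match. A standalone sketch of that rule, with strings standing in for MultiAddress:

import std/[tables, sequtils]

let observations = @["/ip4/1.2.3.4/tcp/4001", "/ip4/1.2.3.4/tcp/4001",
                     "/ip4/1.2.3.4/tcp/4001", "/ip4/5.6.7.8/tcp/4001"]
var counts = toCountTable(observations)
counts.sort()                              # most frequent first
let (winner, count) = toSeq(counts.pairs)[0]
assert count >= 3 and winner == "/ip4/1.2.3.4/tcp/4001"   # minCount = 3 satisfied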

View File

@@ -9,7 +9,10 @@
## This module implements the API for a libp2p peer.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
{.push public.}
import
@@ -18,18 +21,17 @@ import
chronicles,
nimcrypto/utils,
utility,
./crypto/crypto,
./multicodec,
./multihash,
./vbuffer,
./crypto/crypto, ./multicodec, ./multihash, ./vbuffer,
./protobuf/minprotobuf
export results, utility
const maxInlineKeyLength* = 42
const
maxInlineKeyLength* = 42
type PeerId* = object
data*: seq[byte]
type
PeerId* = object
data*: seq[byte]
func `$`*(pid: PeerId): string =
## Return base58 encoded ``pid`` representation.
@@ -42,12 +44,14 @@ func shortLog*(pid: PeerId): string =
if len(spid) > 10:
spid[3] = '*'
spid.delete(4 .. spid.high - 6)
when (NimMajor, NimMinor) > (1, 4):
spid.delete(4 .. spid.high - 6)
else:
spid.delete(4, spid.high - 6)
spid
chronicles.formatIt(PeerId):
shortLog(it)
chronicles.formatIt(PeerId): shortLog(it)
func toBytes*(pid: PeerId, data: var openArray[byte]): int =
## Store PeerId ``pid`` to array of bytes ``data``.
@@ -80,8 +84,7 @@ func cmp*(a, b: PeerId): int =
var m = min(len(a.data), len(b.data))
while i < m:
result = ord(a.data[i]) - ord(b.data[i])
if result != 0:
return
if result != 0: return
inc(i)
result = len(a.data) - len(b.data)
@@ -168,18 +171,18 @@ func init*(t: typedesc[PeerId], data: string): Result[PeerId, cstring] =
func init*(t: typedesc[PeerId], pubkey: PublicKey): Result[PeerId, cstring] =
## Create new peer id from public key ``pubkey``.
var pubraw =
?pubkey.getBytes().orError(cstring("peerid: failed to get bytes from given key"))
var pubraw = ? pubkey.getBytes().orError(
cstring("peerid: failed to get bytes from given key"))
var mh: MultiHash
if len(pubraw) <= maxInlineKeyLength:
mh = ?MultiHash.digest("identity", pubraw)
mh = ? MultiHash.digest("identity", pubraw)
else:
mh = ?MultiHash.digest("sha2-256", pubraw)
mh = ? MultiHash.digest("sha2-256", pubraw)
ok(PeerId(data: mh.data.buffer))
func init*(t: typedesc[PeerId], seckey: PrivateKey): Result[PeerId, cstring] =
## Create new peer id from private key ``seckey``.
PeerId.init(?seckey.getPublicKey().orError(cstring("invalid private key")))
PeerId.init(? seckey.getPublicKey().orError(cstring("invalid private key")))
proc random*(t: typedesc[PeerId], rng = newRng()): Result[PeerId, cstring] =
## Create new peer id with random public key.
@@ -188,11 +191,19 @@ proc random*(t: typedesc[PeerId], rng = newRng()): Result[PeerId, cstring] =
func match*(pid: PeerId, pubkey: PublicKey): bool =
## Returns ``true`` if ``pid`` matches public key ``pubkey``.
PeerId.init(pubkey) == Result[PeerId, cstring].ok(pid)
let p = PeerId.init(pubkey)
if p.isErr:
false
else:
pid == p.get()
func match*(pid: PeerId, seckey: PrivateKey): bool =
## Returns ``true`` if ``pid`` matches private key ``seckey``.
PeerId.init(seckey) == Result[PeerId, cstring].ok(pid)
let p = PeerId.init(seckey)
if p.isErr:
false
else:
pid == p.get()
## Serialization/Deserialization helpers
@@ -204,13 +215,12 @@ func write*(pb: var ProtoBuffer, field: int, pid: PeerId) =
## Write PeerId value ``peerid`` to object ``pb`` using ProtoBuf's encoding.
write(pb, field, pid.data)
func getField*(
pb: ProtoBuffer, field: int, pid: var PeerId
): ProtoResult[bool] {.inline.} =
func getField*(pb: ProtoBuffer, field: int,
pid: var PeerId): ProtoResult[bool] {.inline.} =
## Read ``PeerId`` from ProtoBuf's message and validate it
var buffer: seq[byte]
let res = ?pb.getField(field, buffer)
if not (res):
let res = ? pb.getField(field, buffer)
if not(res):
ok(false)
else:
var peerId: PeerId
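The maxInlineKeyLength* = 42 constant decides the hashing branch in PeerId.init: serialized public keys up to 42 bytes are embedded verbatim via the "identity" multihash (Ed25519 keys fit), while longer ones (e.g. RSA) are hashed with "sha2-256". Hypothetical usage of the constructors diffed above, assuming the usual nim-libp2p crypto helpers (newRng, PrivateKey.random):

let rng = newRng()
let seckey = PrivateKey.random(rng[]).expect("random key")
let pid = PeerId.init(seckey).expect("peer id")
assert pid.match(seckey)   # init and match agree on the same key
echo $pid                  # base58-encoded representation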

View File

@@ -7,10 +7,13 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
{.push public.}
import std/sequtils
import std/[options, sequtils]
import pkg/[chronos, chronicles, stew/results]
import peerid, multiaddress, multicodec, crypto/crypto, routing_record, errors, utility
@@ -21,18 +24,15 @@ export peerid, multiaddress, crypto, routing_record, errors, results
type
PeerInfoError* = object of LPError
AddressMapper* = proc(listenAddrs: seq[MultiAddress]): Future[seq[MultiAddress]] {.
gcsafe, raises: []
.} ## A proc that is expected to resolve the listen addresses into dialable addresses
AddressMapper* =
proc(listenAddrs: seq[MultiAddress]): Future[seq[MultiAddress]]
{.gcsafe, raises: [Defect].}
PeerInfo* {.public.} = ref object
peerId*: PeerId
listenAddrs*: seq[MultiAddress]
## contains addresses the node listens on, which may include wildcard and private addresses (not directly reachable).
addrs*: seq[MultiAddress]
## contains resolved addresses that other peers can use to connect, including public-facing NAT and port-forwarded addresses.
addrs: seq[MultiAddress]
addressMappers*: seq[AddressMapper]
## contains a list of procs that can be used to resolve the listen addresses into dialable addresses.
protocols*: seq[string]
protoVersion*: string
agentVersion*: string
@@ -49,57 +49,59 @@ func shortLog*(p: PeerInfo): auto =
protoVersion: p.protoVersion,
agentVersion: p.agentVersion,
)
chronicles.formatIt(PeerInfo):
shortLog(it)
chronicles.formatIt(PeerInfo): shortLog(it)
proc update*(p: PeerInfo) {.async.} =
# p.addrs.len == 0 overrides addrs only the first time update runs, or when the field is empty.
# p.addressMappers.len == 0 is for when all addressMappers have been removed,
# and we wish to have addrs in its initial state, i.e., a copy of listenAddrs.
if p.addrs.len == 0 or p.addressMappers.len == 0:
p.addrs = p.listenAddrs
p.addrs = p.listenAddrs
for mapper in p.addressMappers:
p.addrs = await mapper(p.addrs)
p.signedPeerRecord = SignedPeerRecord.init(
p.privateKey, PeerRecord.init(p.peerId, p.addrs)
).valueOr:
info "Can't update the signed peer record"
return
let sprRes = SignedPeerRecord.init(
p.privateKey,
PeerRecord.init(p.peerId, p.addrs)
)
if sprRes.isOk:
p.signedPeerRecord = sprRes.get()
else:
discard
#info "Can't update the signed peer record"
proc addrs*(p: PeerInfo): seq[MultiAddress] =
p.addrs
proc fullAddrs*(p: PeerInfo): MaResult[seq[MultiAddress]] =
let peerIdPart = ?MultiAddress.init(multiCodec("p2p"), p.peerId.data)
let peerIdPart = ? MultiAddress.init(multiCodec("p2p"), p.peerId.data)
var res: seq[MultiAddress]
for address in p.addrs:
res.add(?concat(address, peerIdPart))
res.add(? concat(address, peerIdPart))
ok(res)
proc parseFullAddress*(ma: MultiAddress): MaResult[(PeerId, MultiAddress)] =
let p2pPart = ?ma[^1]
if ?p2pPart.protoCode != multiCodec("p2p"):
let p2pPart = ? ma[^1]
if ? p2pPart.protoCode != multiCodec("p2p"):
return err("Missing p2p part from multiaddress!")
let res =
(?PeerId.init(?p2pPart.protoArgument()).orErr("invalid peerid"), ?ma[0 .. ^2])
let res = (
? PeerId.init(? p2pPart.protoArgument()).orErr("invalid peerid"),
? ma[0 .. ^2]
)
ok(res)
proc parseFullAddress*(ma: string | seq[byte]): MaResult[(PeerId, MultiAddress)] =
parseFullAddress(?MultiAddress.init(ma))
parseFullAddress(? MultiAddress.init(ma))
proc new*(
p: typedesc[PeerInfo],
key: PrivateKey,
listenAddrs: openArray[MultiAddress] = [],
protocols: openArray[string] = [],
protoVersion: string = "",
agentVersion: string = "",
addressMappers = newSeq[AddressMapper](),
): PeerInfo {.raises: [LPError].} =
let pubkey =
try:
p: typedesc[PeerInfo],
key: PrivateKey,
listenAddrs: openArray[MultiAddress] = [],
protocols: openArray[string] = [],
protoVersion: string = "",
agentVersion: string = "",
addressMappers = newSeq[AddressMapper](),
): PeerInfo
{.raises: [Defect, LPError].} =
let pubkey = try:
key.getPublicKey().tryGet()
except CatchableError:
raise newException(PeerInfoError, "invalid private key")
@@ -114,7 +116,7 @@ proc new*(
agentVersion: agentVersion,
listenAddrs: @listenAddrs,
protocols: @protocols,
addressMappers: addressMappers,
addressMappers: addressMappers
)
return peerInfo
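An AddressMapper is any async proc from listen addresses to dialable addresses; update pipes addrs through every registered mapper and re-signs the peer record. A hypothetical mapper that appends one statically known public address (the key and the address are placeholders):

proc staticMapper(listenAddrs: seq[MultiAddress]): Future[seq[MultiAddress]] {.async.} =
  result = listenAddrs
  result.add(MultiAddress.init("/ip4/203.0.113.7/tcp/4001").tryGet())

let mappers = @[AddressMapper(staticMapper)]
let info = PeerInfo.new(key, addressMappers = mappers)
await info.update()   # addrs now ends with the static public address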

View File

@@ -16,33 +16,31 @@ runnableExamples:
# Create a custom book type
type MoodBook = ref object of PeerBook[string]
var somePeerId = PeerId.random().expect("get random key")
var somePeerId = PeerId.random().get()
peerStore[MoodBook][somePeerId] = "Happy"
doAssert peerStore[MoodBook][somePeerId] == "Happy"
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import
std/[tables, sets, options, macros],
chronos,
./crypto/crypto,
./protocols/identify,
./protocols/protocol,
./peerid,
./peerinfo,
./peerid, ./peerinfo,
./routing_record,
./multiaddress,
./stream/connection,
./multistream,
./muxers/muxer,
utility
type
#################
# Handler types #
#################
PeerBookChangeHandler* = proc(peerId: PeerId) {.gcsafe, raises: [].}
PeerBookChangeHandler* = proc(peerId: PeerId) {.gcsafe, raises: [Defect].}
#########
# Books #
@@ -69,26 +67,27 @@ type
####################
# Peer store types #
####################
PeerStore* {.public.} = ref object
books: Table[string, BasePeerBook]
identify: Identify
capacity*: int
toClean*: seq[PeerId]
proc new*(
T: type PeerStore, identify: Identify, capacity = 1000
): PeerStore {.public.} =
T(identify: identify, capacity: capacity)
proc new*(T: type PeerStore, capacity = 1000): PeerStore {.public.} =
T(capacity: capacity)
#########################
# Generic Peer Book API #
#########################
proc `[]`*[T](peerBook: PeerBook[T], peerId: PeerId): T {.public.} =
proc `[]`*[T](peerBook: PeerBook[T],
peerId: PeerId): T {.public.} =
## Get all known metadata of a provided peer, or default(T) if missing
peerBook.book.getOrDefault(peerId)
proc `[]=`*[T](peerBook: PeerBook[T], peerId: PeerId, entry: T) {.public.} =
proc `[]=`*[T](peerBook: PeerBook[T],
peerId: PeerId,
entry: T) {.public.} =
## Set metadata for a given peerId.
peerBook.book[peerId] = entry
@@ -97,7 +96,8 @@ proc `[]=`*[T](peerBook: PeerBook[T], peerId: PeerId, entry: T) {.public.} =
for handler in peerBook.changeHandlers:
handler(peerId)
proc del*[T](peerBook: PeerBook[T], peerId: PeerId): bool {.public.} =
proc del*[T](peerBook: PeerBook[T],
peerId: PeerId): bool {.public.} =
## Delete the provided peer from the book. Returns whether the peer was in the book
if peerId notin peerBook.book:
@@ -116,8 +116,7 @@ proc addHandler*[T](peerBook: PeerBook[T], handler: PeerBookChangeHandler) {.pub
## Adds a callback that will be called every time the book changes
peerBook.changeHandlers.add(handler)
proc len*[T](peerBook: PeerBook[T]): int {.public.} =
peerBook.book.len
proc len*[T](peerBook: PeerBook[T]): int {.public.} = peerBook.book.len
##################
# Peer Store API #
@@ -140,35 +139,42 @@ proc `[]`*[T](p: PeerStore, typ: type[T]): T {.public.} =
p.books[name] = result
return result
proc del*(peerStore: PeerStore, peerId: PeerId) {.public.} =
proc del*(peerStore: PeerStore,
peerId: PeerId) {.public.} =
## Delete the provided peer from every book.
for _, book in peerStore.books:
book.deletor(peerId)
proc updatePeerInfo*(peerStore: PeerStore, info: IdentifyInfo) =
proc updatePeerInfo*(
peerStore: PeerStore,
info: IdentifyInfo) =
if info.addrs.len > 0:
peerStore[AddressBook][info.peerId] = info.addrs
info.pubkey.withValue(pubkey):
peerStore[KeyBook][info.peerId] = pubkey
if info.pubkey.isSome:
peerStore[KeyBook][info.peerId] = info.pubkey.get()
info.agentVersion.withValue(agentVersion):
peerStore[AgentBook][info.peerId] = agentVersion.string
if info.agentVersion.isSome:
peerStore[AgentBook][info.peerId] = info.agentVersion.get().string
info.protoVersion.withValue(protoVersion):
peerStore[ProtoVersionBook][info.peerId] = protoVersion.string
if info.protoVersion.isSome:
peerStore[ProtoVersionBook][info.peerId] = info.protoVersion.get().string
if info.protos.len > 0:
peerStore[ProtoBook][info.peerId] = info.protos
info.signedPeerRecord.withValue(signedPeerRecord):
peerStore[SPRBook][info.peerId] = signedPeerRecord
if info.signedPeerRecord.isSome:
peerStore[SPRBook][info.peerId] = info.signedPeerRecord.get()
let cleanupPos = peerStore.toClean.find(info.peerId)
if cleanupPos >= 0:
peerStore.toClean.delete(cleanupPos)
proc cleanup*(peerStore: PeerStore, peerId: PeerId) =
proc cleanup*(
peerStore: PeerStore,
peerId: PeerId) =
if peerStore.capacity == 0:
peerStore.del(peerId)
return
@@ -180,32 +186,3 @@ proc cleanup*(peerStore: PeerStore, peerId: PeerId) =
while peerStore.toClean.len > peerStore.capacity:
peerStore.del(peerStore.toClean[0])
peerStore.toClean.delete(0)
proc identify*(peerStore: PeerStore, muxer: Muxer) {.async.} =
# new stream for identify
var stream = await muxer.newStream()
if stream == nil:
return
try:
if (await MultistreamSelect.select(stream, peerStore.identify.codec())):
let info = await peerStore.identify.identify(stream, stream.peerId)
when defined(libp2p_agents_metrics):
var
knownAgent = "unknown"
shortAgent =
info.agentVersion.get("").split("/")[0].safeToLowerAscii().get("")
if KnownLibP2PAgentsSeq.contains(shortAgent):
knownAgent = shortAgent
muxer.connection.setShortAgent(knownAgent)
peerStore.updatePeerInfo(info)
finally:
await stream.closeWithEOF()
proc getMostObservedProtosAndPorts*(self: PeerStore): seq[MultiAddress] =
return self.identify.observedAddrManager.getMostObservedProtosAndPorts()
proc guessDialableAddr*(self: PeerStore, ma: MultiAddress): MultiAddress =
return self.identify.observedAddrManager.guessDialableAddr(ma)
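Books are keyed by type: peerStore[AddressBook] lazily creates and returns the typed book, and change handlers fire on every []= and del. A hypothetical handler registration against the API above, using the capacity-only constructor from this diff (somePeerId and someAddrs are placeholders):

let store = PeerStore.new(capacity = 100)
var changed: seq[PeerId]
store[AddressBook].addHandler(proc(peerId: PeerId) =
  changed.add(peerId)
)
store[AddressBook][somePeerId] = someAddrs   # handler fires here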

View File

@@ -9,7 +9,10 @@
## This module implements a minimal subset of Google's ProtoBuf primitives.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import ../varint, ../utility, stew/[endians2, results]
export results, utility
@@ -21,20 +24,14 @@ const MaxMessageSize = 1'u shl 22
type
ProtoFieldKind* = enum
## Protobuf's field types enum
Varint
Fixed64
Length
StartGroup
EndGroup
Fixed32
Varint, Fixed64, Length, StartGroup, EndGroup, Fixed32
ProtoFlags* = enum
## Protobuf's encoding types
WithVarintLength
WithUint32BeLength
WithUint32LeLength
WithVarintLength, WithUint32BeLength, WithUint32LeLength
ProtoBuffer* = object ## Protobuf's message representation object
ProtoBuffer* = object
## Protobuf's message representation object
options: set[ProtoFlags]
buffer*: seq[byte]
offset*: int
@@ -45,7 +42,8 @@ type
wire*: ProtoFieldKind
index*: uint64
ProtoField* = object ## Protobuf's message field representation object
ProtoField* = object
## Protobuf's message field representation object
index*: int
case kind*: ProtoFieldKind
of Varint:
@@ -60,33 +58,30 @@ type
discard
ProtoError* {.pure.} = enum
VarintDecode
MessageIncomplete
BufferOverflow
MessageTooBig
BadWireType
IncorrectBlob
VarintDecode,
MessageIncomplete,
BufferOverflow,
MessageTooBig,
BadWireType,
IncorrectBlob,
RequiredFieldMissing
ProtoResult*[T] = Result[T, ProtoError]
ProtoScalar* =
uint | uint32 | uint64 | zint | zint32 | zint64 | hint | hint32 | hint64 | float32 |
float64
ProtoScalar* = uint | uint32 | uint64 | zint | zint32 | zint64 |
hint | hint32 | hint64 | float32 | float64
const SupportedWireTypes* =
@[
uint64(ProtoFieldKind.Varint),
uint64(ProtoFieldKind.Fixed64),
uint64(ProtoFieldKind.Length),
uint64(ProtoFieldKind.Fixed32),
]
const
SupportedWireTypes* = {
int(ProtoFieldKind.Varint),
int(ProtoFieldKind.Fixed64),
int(ProtoFieldKind.Length),
int(ProtoFieldKind.Fixed32)
}
template checkFieldNumber*(i: int) =
doAssert(
(i > 0 and i < (1 shl 29)) and not (i >= 19000 and i <= 19999),
"Incorrect or reserved field number",
)
doAssert((i > 0 and i < (1 shl 29)) and not(i >= 19000 and i <= 19999),
"Incorrect or reserved field number")
template getProtoHeader*(index: int, wire: ProtoFieldKind): uint64 =
## Get protobuf's field header integer for ``index`` and ``wire``.
@@ -122,34 +117,29 @@ proc vsizeof*(field: ProtoField): int {.inline.} =
vsizeof(getProtoHeader(field)) + sizeof(field.vfloat32)
of ProtoFieldKind.Length:
vsizeof(getProtoHeader(field)) + vsizeof(uint64(len(field.vbuffer))) +
len(field.vbuffer)
len(field.vbuffer)
else:
0
proc initProtoBuffer*(
data: seq[byte], offset = 0, options: set[ProtoFlags] = {}, maxSize = MaxMessageSize
): ProtoBuffer =
proc initProtoBuffer*(data: seq[byte], offset = 0,
options: set[ProtoFlags] = {},
maxSize = MaxMessageSize): ProtoBuffer =
## Initialize ProtoBuffer with shallow copy of ``data``.
result.buffer = data
result.offset = offset
result.options = options
result.maxSize = maxSize
proc initProtoBuffer*(
data: openArray[byte],
offset = 0,
options: set[ProtoFlags] = {},
maxSize = MaxMessageSize,
): ProtoBuffer =
proc initProtoBuffer*(data: openArray[byte], offset = 0,
options: set[ProtoFlags] = {},
maxSize = MaxMessageSize): ProtoBuffer =
## Initialize ProtoBuffer with copy of ``data``.
result.buffer = @data
result.offset = offset
result.options = options
result.maxSize = maxSize
proc initProtoBuffer*(
options: set[ProtoFlags] = {}, maxSize = MaxMessageSize
): ProtoBuffer =
proc initProtoBuffer*(options: set[ProtoFlags] = {}, maxSize = MaxMessageSize): ProtoBuffer =
## Initialize ProtoBuffer with new sequence of capacity ``cap``.
result.buffer = newSeq[byte]()
result.options = options
@@ -165,31 +155,37 @@ proc initProtoBuffer*(
result.buffer.setLen(4)
result.offset = 4
proc write*[T: ProtoScalar](pb: var ProtoBuffer, field: int, value: T) =
proc write*[T: ProtoScalar](pb: var ProtoBuffer,
field: int, value: T) =
checkFieldNumber(field)
var length = 0
when (T is uint64) or (T is uint32) or (T is uint) or (T is zint64) or (T is zint32) or
(T is zint) or (T is hint64) or (T is hint32) or (T is hint):
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Varint)) + vsizeof(value)
when (T is uint64) or (T is uint32) or (T is uint) or
(T is zint64) or (T is zint32) or (T is zint) or
(T is hint64) or (T is hint32) or (T is hint):
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Varint)) +
vsizeof(value)
let header = ProtoFieldKind.Varint
elif T is float32:
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Fixed32)) + sizeof(T)
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Fixed32)) +
sizeof(T)
let header = ProtoFieldKind.Fixed32
elif T is float64:
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Fixed64)) + sizeof(T)
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Fixed64)) +
sizeof(T)
let header = ProtoFieldKind.Fixed64
pb.buffer.setLen(len(pb.buffer) + flength)
let hres = PB.putUVarint(pb.toOpenArray(), length, getProtoHeader(field, header))
let hres = PB.putUVarint(pb.toOpenArray(), length,
getProtoHeader(field, header))
doAssert(hres.isOk())
pb.offset += length
when (T is uint64) or (T is uint32) or (T is uint):
let vres = PB.putUVarint(pb.toOpenArray(), length, value)
doAssert(vres.isOk())
pb.offset += length
elif (T is zint64) or (T is zint32) or (T is zint) or (T is hint64) or (T is hint32) or
(T is hint):
elif (T is zint64) or (T is zint32) or (T is zint) or
(T is hint64) or (T is hint32) or (T is hint):
let vres = putSVarint(pb.toOpenArray(), length, value)
doAssert(vres.isOk())
pb.offset += length
@@ -204,14 +200,14 @@ proc write*[T: ProtoScalar](pb: var ProtoBuffer, field: int, value: T) =
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u64.toBytesLE()
pb.offset += sizeof(T)
proc writePacked*[T: ProtoScalar](
pb: var ProtoBuffer, field: int, value: openArray[T]
) =
proc writePacked*[T: ProtoScalar](pb: var ProtoBuffer, field: int,
value: openArray[T]) =
checkFieldNumber(field)
var length = 0
let dlength =
when (T is uint64) or (T is uint32) or (T is uint) or (T is zint64) or (T is zint32) or
(T is zint) or (T is hint64) or (T is hint32) or (T is hint):
when (T is uint64) or (T is uint32) or (T is uint) or
(T is zint64) or (T is zint32) or (T is zint) or
(T is hint64) or (T is hint32) or (T is hint):
var res = 0
for item in value:
res += vsizeof(item)
@@ -235,8 +231,8 @@ proc writePacked*[T: ProtoScalar](
let vres = PB.putUVarint(pb.toOpenArray(), length, item)
doAssert(vres.isOk())
pb.offset += length
elif (T is zint64) or (T is zint32) or (T is zint) or (T is hint64) or (T is hint32) or
(T is hint):
elif (T is zint64) or (T is zint32) or (T is zint) or
(T is hint64) or (T is hint32) or (T is hint):
length = 0
let vres = PB.putSVarint(pb.toOpenArray(), length, item)
doAssert(vres.isOk())
@@ -252,19 +248,19 @@ proc writePacked*[T: ProtoScalar](
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u64.toBytesLE()
pb.offset += sizeof(T)
proc write*[T: byte | char](pb: var ProtoBuffer, field: int, value: openArray[T]) =
proc write*[T: byte|char](pb: var ProtoBuffer, field: int,
value: openArray[T]) =
checkFieldNumber(field)
var length = 0
let flength =
vsizeof(getProtoHeader(field, ProtoFieldKind.Length)) + vsizeof(uint64(len(value))) +
len(value)
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Length)) +
vsizeof(uint64(len(value))) + len(value)
pb.buffer.setLen(len(pb.buffer) + flength)
let hres = PB.putUVarint(
pb.toOpenArray(), length, getProtoHeader(field, ProtoFieldKind.Length)
)
let hres = PB.putUVarint(pb.toOpenArray(), length,
getProtoHeader(field, ProtoFieldKind.Length))
doAssert(hres.isOk())
pb.offset += length
let lres = PB.putUVarint(pb.toOpenArray(), length, uint64(len(value)))
let lres = PB.putUVarint(pb.toOpenArray(), length,
uint64(len(value)))
doAssert(lres.isOk())
pb.offset += length
if len(value) > 0:
@@ -301,7 +297,8 @@ proc finish*(pb: var ProtoBuffer) =
doAssert(len(pb.buffer) > 0)
pb.offset = 0
proc getHeader(data: var ProtoBuffer, header: var ProtoHeader): ProtoResult[void] =
proc getHeader(data: var ProtoBuffer,
header: var ProtoHeader): ProtoResult[void] =
var length = 0
var hdr = 0'u64
if PB.getUVarint(data.toOpenArray(), length, hdr).isOk():
@@ -356,9 +353,9 @@ proc skipValue(data: var ProtoBuffer, header: ProtoHeader): ProtoResult[void] =
of ProtoFieldKind.StartGroup, ProtoFieldKind.EndGroup:
err(ProtoError.BadWireType)
proc getValue[T: ProtoScalar](
data: var ProtoBuffer, header: ProtoHeader, outval: var T
): ProtoResult[void] =
proc getValue[T: ProtoScalar](data: var ProtoBuffer,
header: ProtoHeader,
outval: var T): ProtoResult[void] =
when (T is uint64) or (T is uint32) or (T is uint):
doAssert(header.wire == ProtoFieldKind.Varint)
var length = 0
@@ -369,8 +366,8 @@ proc getValue[T: ProtoScalar](
ok()
else:
err(ProtoError.VarintDecode)
elif (T is zint64) or (T is zint32) or (T is zint) or (T is hint64) or (T is hint32) or
(T is hint):
elif (T is zint64) or (T is zint32) or (T is zint) or
(T is hint64) or (T is hint32) or (T is hint):
doAssert(header.wire == ProtoFieldKind.Varint)
var length = 0
var value = T(0)
@@ -397,12 +394,9 @@ proc getValue[T: ProtoScalar](
else:
err(ProtoError.MessageIncomplete)
proc getValue[T: byte | char](
data: var ProtoBuffer,
header: ProtoHeader,
outBytes: var openArray[T],
outLength: var int,
): ProtoResult[void] =
proc getValue[T:byte|char](data: var ProtoBuffer, header: ProtoHeader,
outBytes: var openArray[T],
outLength: var int): ProtoResult[void] =
doAssert(header.wire == ProtoFieldKind.Length)
var length = 0
var bsize = 0'u64
@@ -429,9 +423,8 @@ proc getValue[T: byte | char](
else:
err(ProtoError.VarintDecode)
proc getValue[T: seq[byte] | string](
data: var ProtoBuffer, header: ProtoHeader, outBytes: var T
): ProtoResult[void] =
proc getValue[T:seq[byte]|string](data: var ProtoBuffer, header: ProtoHeader,
outBytes: var T): ProtoResult[void] =
doAssert(header.wire == ProtoFieldKind.Length)
var length = 0
var bsize = 0'u64
@@ -453,20 +446,20 @@ proc getValue[T: seq[byte] | string](
else:
err(ProtoError.VarintDecode)
proc getField*[T: ProtoScalar](
data: ProtoBuffer, field: int, output: var T
): ProtoResult[bool] =
proc getField*[T: ProtoScalar](data: ProtoBuffer, field: int,
output: var T): ProtoResult[bool] =
checkFieldNumber(field)
var current: T
var res = false
var pb = data
while not (pb.isEmpty()):
while not(pb.isEmpty()):
var header: ProtoHeader
?pb.getHeader(header)
? pb.getHeader(header)
let wireCheck =
when (T is uint64) or (T is uint32) or (T is uint) or (T is zint64) or
(T is zint32) or (T is zint) or (T is hint64) or (T is hint32) or (T is hint):
when (T is uint64) or (T is uint32) or (T is uint) or
(T is zint64) or (T is zint32) or (T is zint) or
(T is hint64) or (T is hint32) or (T is hint):
header.wire == ProtoFieldKind.Varint
elif T is float32:
header.wire == ProtoFieldKind.Fixed32
@@ -484,9 +477,9 @@ proc getField*[T: ProtoScalar](
else:
# We ignore wire types different from what we expect, because that
# is how `protoc` works.
?pb.skipValue(header)
? pb.skipValue(header)
else:
?pb.skipValue(header)
? pb.skipValue(header)
if res:
output = current
@@ -494,16 +487,16 @@ proc getField*[T: ProtoScalar](
else:
ok(false)
proc getField*[T: byte | char](
data: ProtoBuffer, field: int, output: var openArray[T], outlen: var int
): ProtoResult[bool] =
proc getField*[T: byte|char](data: ProtoBuffer, field: int,
output: var openArray[T],
outlen: var int): ProtoResult[bool] =
checkFieldNumber(field)
var pb = data
var res = false
outlen = 0
while not (pb.isEmpty()):
while not(pb.isEmpty()):
var header: ProtoHeader
let hres = pb.getHeader(header)
if hres.isErr():
@@ -546,14 +539,13 @@ proc getField*[T: byte | char](
else:
ok(false)
proc getField*[T: seq[byte] | string](
data: ProtoBuffer, field: int, output: var T
): ProtoResult[bool] =
proc getField*[T: seq[byte]|string](data: ProtoBuffer, field: int,
output: var T): ProtoResult[bool] =
checkFieldNumber(field)
var res = false
var pb = data
while not (pb.isEmpty()):
while not(pb.isEmpty()):
var header: ProtoHeader
let hres = pb.getHeader(header)
if hres.isErr():
@@ -584,32 +576,37 @@ proc getField*[T: seq[byte] | string](
else:
ok(false)
proc getField*(
pb: ProtoBuffer, field: int, output: var ProtoBuffer
): ProtoResult[bool] {.inline.} =
proc getField*(pb: ProtoBuffer, field: int,
output: var ProtoBuffer): ProtoResult[bool] {.inline.} =
var buffer: seq[byte]
if ?pb.getField(field, buffer):
output = initProtoBuffer(buffer)
ok(true)
let res = pb.getField(field, buffer)
if res.isOk():
if res.get():
output = initProtoBuffer(buffer)
ok(true)
else:
ok(false)
else:
ok(false)
err(res.error)
proc getRequiredField*[T](
pb: ProtoBuffer, field: int, output: var T
): ProtoResult[void] {.inline.} =
if ?pb.getField(field, output):
ok()
proc getRequiredField*[T](pb: ProtoBuffer, field: int,
output: var T): ProtoResult[void] {.inline.} =
let res = pb.getField(field, output)
if res.isOk():
if res.get():
ok()
else:
err(RequiredFieldMissing)
else:
err(RequiredFieldMissing)
err(res.error)
proc getRepeatedField*[T: seq[byte] | string](
data: ProtoBuffer, field: int, output: var seq[T]
): ProtoResult[bool] =
proc getRepeatedField*[T: seq[byte]|string](data: ProtoBuffer, field: int,
output: var seq[T]): ProtoResult[bool] =
checkFieldNumber(field)
var pb = data
output.setLen(0)
while not (pb.isEmpty()):
while not(pb.isEmpty()):
var header: ProtoHeader
let hres = pb.getHeader(header)
if hres.isErr():
@@ -640,14 +637,13 @@ proc getRepeatedField*[T: seq[byte] | string](
else:
ok(false)
proc getRepeatedField*[T: ProtoScalar](
data: ProtoBuffer, field: int, output: var seq[T]
): ProtoResult[bool] =
proc getRepeatedField*[T: ProtoScalar](data: ProtoBuffer, field: int,
output: var seq[T]): ProtoResult[bool] =
checkFieldNumber(field)
var pb = data
output.setLen(0)
while not (pb.isEmpty()):
while not(pb.isEmpty()):
var header: ProtoHeader
let hres = pb.getHeader(header)
if hres.isErr():
@@ -655,8 +651,8 @@ proc getRepeatedField*[T: ProtoScalar](
return err(hres.error)
if header.index == uint64(field):
if header.wire in
{ProtoFieldKind.Varint, ProtoFieldKind.Fixed32, ProtoFieldKind.Fixed64}:
if header.wire in {ProtoFieldKind.Varint, ProtoFieldKind.Fixed32,
ProtoFieldKind.Fixed64}:
var item: T
let vres = getValue(pb, header, item)
if vres.isOk():
@@ -680,22 +676,24 @@ proc getRepeatedField*[T: ProtoScalar](
else:
ok(false)
proc getRequiredRepeatedField*[T](
pb: ProtoBuffer, field: int, output: var seq[T]
): ProtoResult[void] {.inline.} =
if ?pb.getRepeatedField(field, output):
ok()
proc getRequiredRepeatedField*[T](pb: ProtoBuffer, field: int,
output: var seq[T]): ProtoResult[void] {.inline.} =
let res = pb.getRepeatedField(field, output)
if res.isOk():
if res.get():
ok()
else:
err(RequiredFieldMissing)
else:
err(RequiredFieldMissing)
err(res.error)
proc getPackedRepeatedField*[T: ProtoScalar](
data: ProtoBuffer, field: int, output: var seq[T]
): ProtoResult[bool] =
proc getPackedRepeatedField*[T: ProtoScalar](data: ProtoBuffer, field: int,
output: var seq[T]): ProtoResult[bool] =
checkFieldNumber(field)
var pb = data
output.setLen(0)
while not (pb.isEmpty()):
while not(pb.isEmpty()):
var header: ProtoHeader
let hres = pb.getHeader(header)
if hres.isErr():
@@ -709,15 +707,15 @@ proc getPackedRepeatedField*[T: ProtoScalar](
if ares.isOk():
var pbarr = initProtoBuffer(arritem)
let itemHeader =
when (T is uint64) or (T is uint32) or (T is uint) or (T is zint64) or
(T is zint32) or (T is zint) or (T is hint64) or (T is hint32) or
(T is hint):
when (T is uint64) or (T is uint32) or (T is uint) or
(T is zint64) or (T is zint32) or (T is zint) or
(T is hint64) or (T is hint32) or (T is hint):
ProtoHeader(wire: ProtoFieldKind.Varint)
elif T is float32:
ProtoHeader(wire: ProtoFieldKind.Fixed32)
elif T is float64:
ProtoHeader(wire: ProtoFieldKind.Fixed64)
while not (pbarr.isEmpty()):
while not(pbarr.isEmpty()):
var item: T
let vres = getValue(pbarr, itemHeader, item)
if vres.isOk():
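Every field header written or parsed above reduces to (index shl 3) or wireType, as produced by getProtoHeader. A standalone check of that math:

# field 1, wire type Varint (0) -> 0x08; field 2, wire type Length (2) -> 0x12
assert ((1'u64 shl 3) or 0'u64) == 0x08'u64
assert ((2'u64 shl 3) or 2'u64) == 0x12'u64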

View File

@@ -7,39 +7,46 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import stew/results
import std/[options, sets, sequtils]
import stew/[results, objects]
import chronos, chronicles
import ../../../switch, ../../../multiaddress, ../../../peerid
import ../../../switch,
../../../multiaddress,
../../../peerid
import core
export core
logScope:
topics = "libp2p autonat"
type AutonatClient* = ref object of RootObj
type
AutonatClient* = ref object of RootObj
proc sendDial(conn: Connection, pid: PeerId, addrs: seq[MultiAddress]) {.async.} =
let pb = AutonatDial(
peerInfo: Opt.some(AutonatPeerInfo(id: Opt.some(pid), addrs: addrs))
).encode()
let pb = AutonatDial(peerInfo: some(AutonatPeerInfo(
id: some(pid),
addrs: addrs
))).encode()
await conn.writeLp(pb.buffer)
method dialMe*(
self: AutonatClient,
switch: Switch,
pid: PeerId,
addrs: seq[MultiAddress] = newSeq[MultiAddress](),
): Future[MultiAddress] {.base, async.} =
proc getResponseOrRaise(
autonatMsg: Opt[AutonatMsg]
): AutonatDialResponse {.raises: [AutonatError].} =
autonatMsg.withValue(msg):
if msg.msgType == DialResponse:
msg.response.withValue(res):
if not (res.status == Ok and res.ma.isNone()):
return res
raise newException(AutonatError, "Unexpected response")
method dialMe*(self: AutonatClient, switch: Switch, pid: PeerId, addrs: seq[MultiAddress] = newSeq[MultiAddress]()):
Future[MultiAddress] {.base, async.} =
proc getResponseOrRaise(autonatMsg: Option[AutonatMsg]): AutonatDialResponse {.raises: [UnpackError, AutonatError].} =
if autonatMsg.isNone() or
autonatMsg.get().msgType != DialResponse or
autonatMsg.get().response.isNone() or
(autonatMsg.get().response.get().status == Ok and
autonatMsg.get().response.get().ma.isNone()):
raise newException(AutonatError, "Unexpected response")
else:
autonatMsg.get().response.get()
let conn =
try:
@@ -48,32 +55,20 @@ method dialMe*(
else:
await switch.dial(pid, addrs, AutonatCodec)
except CatchableError as err:
raise
newException(AutonatError, "Unexpected error when dialling: " & err.msg, err)
raise newException(AutonatError, "Unexpected error when dialling: " & err.msg, err)
# To bypass maxConnectionsPerPeer
let incomingConnection = switch.connManager.expectConnection(pid, In)
if incomingConnection.failed() and
incomingConnection.error of AlreadyExpectingConnectionError:
raise newException(AutonatError, incomingConnection.error.msg)
let incomingConnection = switch.connManager.expectConnection(pid)
defer:
await conn.close()
incomingConnection.cancel()
# Safer to always try to cancel, because we aren't sure whether the peer dialled us
if incomingConnection.completed():
await (await incomingConnection).connection.close()
incomingConnection.cancel() # Safer to always try to cancel, because we aren't sure whether the peer dialled us
trace "sending Dial", addrs = switch.peerInfo.addrs
await conn.sendDial(switch.peerInfo.peerId, switch.peerInfo.addrs)
let response = getResponseOrRaise(AutonatMsg.decode(await conn.readLp(1024)))
return
case response.status
return case response.status:
of ResponseStatus.Ok:
response.ma.tryGet()
response.ma.get()
of ResponseStatus.DialError:
raise newException(
AutonatUnreachableError, "Peer could not dial us back: " & response.text.get("")
)
raise newException(AutonatUnreachableError, "Peer could not dial us back: " & response.text.get(""))
else:
raise newException(
AutonatError, "Bad status " & $response.status & " " & response.text.get("")
)
raise newException(AutonatError, "Bad status " & $response.status & " " & response.text.get(""))

View File

@@ -7,11 +7,17 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import std/[options, sets, sequtils]
import stew/[results, objects]
import chronos, chronicles
import ../../../multiaddress, ../../../peerid, ../../../errors
import ../../../multiaddress,
../../../peerid,
../../../errors
logScope:
topics = "libp2p autonat"
@@ -36,31 +42,26 @@ type
InternalError = 300
AutonatPeerInfo* = object
id*: Opt[PeerId]
id*: Option[PeerId]
addrs*: seq[MultiAddress]
AutonatDial* = object
peerInfo*: Opt[AutonatPeerInfo]
peerInfo*: Option[AutonatPeerInfo]
AutonatDialResponse* = object
status*: ResponseStatus
text*: Opt[string]
ma*: Opt[MultiAddress]
text*: Option[string]
ma*: Option[MultiAddress]
AutonatMsg* = object
msgType*: MsgType
dial*: Opt[AutonatDial]
response*: Opt[AutonatDialResponse]
NetworkReachability* {.pure.} = enum
Unknown
NotReachable
Reachable
dial*: Option[AutonatDial]
response*: Option[AutonatDialResponse]
proc encode(p: AutonatPeerInfo): ProtoBuffer =
result = initProtoBuffer()
p.id.withValue(id):
result.write(1, id)
if p.id.isSome():
result.write(1, p.id.get())
for ma in p.addrs:
result.write(2, ma.data.buffer)
result.finish()
@@ -69,8 +70,8 @@ proc encode*(d: AutonatDial): ProtoBuffer =
result = initProtoBuffer()
result.write(1, MsgType.Dial.uint)
var dial = initProtoBuffer()
d.peerInfo.withValue(pinfo):
dial.write(1, encode(pinfo))
if d.peerInfo.isSome():
dial.write(1, encode(d.peerInfo.get()))
dial.finish()
result.write(2, dial.buffer)
result.finish()
@@ -80,62 +81,72 @@ proc encode*(r: AutonatDialResponse): ProtoBuffer =
result.write(1, MsgType.DialResponse.uint)
var bufferResponse = initProtoBuffer()
bufferResponse.write(1, r.status.uint)
r.text.withValue(text):
bufferResponse.write(2, text)
r.ma.withValue(ma):
bufferResponse.write(3, ma)
if r.text.isSome():
bufferResponse.write(2, r.text.get())
if r.ma.isSome():
bufferResponse.write(3, r.ma.get())
bufferResponse.finish()
result.write(3, bufferResponse.buffer)
result.finish()
proc encode*(msg: AutonatMsg): ProtoBuffer =
msg.dial.withValue(dial):
return encode(dial)
msg.response.withValue(res):
return encode(res)
if msg.dial.isSome():
return encode(msg.dial.get())
if msg.response.isSome():
return encode(msg.response.get())
proc decode*(_: typedesc[AutonatMsg], buf: seq[byte]): Opt[AutonatMsg] =
proc decode*(_: typedesc[AutonatMsg], buf: seq[byte]): Option[AutonatMsg] =
var
msgTypeOrd: uint32
pbDial: ProtoBuffer
pbResponse: ProtoBuffer
msg: AutonatMsg
let pb = initProtoBuffer(buf)
let
pb = initProtoBuffer(buf)
r1 = pb.getField(1, msgTypeOrd)
r2 = pb.getField(2, pbDial)
r3 = pb.getField(3, pbResponse)
if r1.isErr() or r2.isErr() or r3.isErr(): return none(AutonatMsg)
if ?pb.getField(1, msgTypeOrd).toOpt() and
not checkedEnumAssign(msg.msgType, msgTypeOrd):
return Opt.none(AutonatMsg)
if ?pb.getField(2, pbDial).toOpt():
if r1.get() and not checkedEnumAssign(msg.msgType, msgTypeOrd):
return none(AutonatMsg)
if r2.get():
var
pbPeerInfo: ProtoBuffer
dial: AutonatDial
let r4 = ?pbDial.getField(1, pbPeerInfo).toOpt()
let
r4 = pbDial.getField(1, pbPeerInfo)
if r4.isErr(): return none(AutonatMsg)
var peerInfo: AutonatPeerInfo
if r4:
if r4.get():
var pid: PeerId
let
r5 = ?pbPeerInfo.getField(1, pid).toOpt()
r6 = ?pbPeerInfo.getRepeatedField(2, peerInfo.addrs).toOpt()
if r5:
peerInfo.id = Opt.some(pid)
dial.peerInfo = Opt.some(peerInfo)
msg.dial = Opt.some(dial)
r5 = pbPeerInfo.getField(1, pid)
r6 = pbPeerInfo.getRepeatedField(2, peerInfo.addrs)
if r5.isErr() or r6.isErr(): return none(AutonatMsg)
if r5.get(): peerInfo.id = some(pid)
dial.peerInfo = some(peerInfo)
msg.dial = some(dial)
if ?pb.getField(3, pbResponse).toOpt():
if r3.get():
var
statusOrd: uint
text: string
ma: MultiAddress
response: AutonatDialResponse
if ?pbResponse.getField(1, statusOrd).optValue():
if not checkedEnumAssign(response.status, statusOrd):
return Opt.none(AutonatMsg)
if ?pbResponse.getField(2, text).optValue():
response.text = Opt.some(text)
if ?pbResponse.getField(3, ma).optValue():
response.ma = Opt.some(ma)
msg.response = Opt.some(response)
return Opt.some(msg)
let
r4 = pbResponse.getField(1, statusOrd)
r5 = pbResponse.getField(2, text)
r6 = pbResponse.getField(3, ma)
if r4.isErr() or r5.isErr() or r6.isErr() or
(r4.get() and not checkedEnumAssign(response.status, statusOrd)):
return none(AutonatMsg)
if r5.get(): response.text = some(text)
if r6.get(): response.ma = some(ma)
msg.response = some(response)
return some(msg)
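A round-trip sketch for these wire messages, using the Opt-based variant shown above (pid and ma are placeholder PeerId/MultiAddress values):

let dial = AutonatDial(
  peerInfo: Opt.some(AutonatPeerInfo(id: Opt.some(pid), addrs: @[ma]))
)
let decoded = AutonatMsg.decode(encode(dial).buffer)
assert decoded.isSome() and decoded.get().msgType == MsgType.Dial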

View File

@@ -7,19 +7,21 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import std/[sets, sequtils]
import std/[options, sets, sequtils]
import stew/results
import chronos, chronicles
import
../../protocol,
../../../switch,
../../../multiaddress,
../../../multicodec,
../../../peerid,
../../../utils/[semaphore, future],
../../../errors
import chronos, chronicles, stew/objects
import ../../protocol,
../../../switch,
../../../multiaddress,
../../../multicodec,
../../../peerid,
../../../utils/[semaphore, future],
../../../errors
import core
export core
@@ -27,36 +29,33 @@ export core
logScope:
topics = "libp2p autonat"
type Autonat* = ref object of LPProtocol
sem: AsyncSemaphore
switch*: Switch
dialTimeout: Duration
type
Autonat* = ref object of LPProtocol
sem: AsyncSemaphore
switch*: Switch
dialTimeout: Duration
proc sendDial(conn: Connection, pid: PeerId, addrs: seq[MultiAddress]) {.async.} =
let pb = AutonatDial(
peerInfo: Opt.some(AutonatPeerInfo(id: Opt.some(pid), addrs: addrs))
).encode()
let pb = AutonatDial(peerInfo: some(AutonatPeerInfo(
id: some(pid),
addrs: addrs
))).encode()
await conn.writeLp(pb.buffer)
proc sendResponseError(
conn: Connection, status: ResponseStatus, text: string = ""
) {.async.} =
proc sendResponseError(conn: Connection, status: ResponseStatus, text: string = "") {.async.} =
let pb = AutonatDialResponse(
status: status,
text:
if text == "":
Opt.none(string)
else:
Opt.some(text)
,
ma: Opt.none(MultiAddress),
).encode()
status: status,
text: if text == "": none(string) else: some(text),
ma: none(MultiAddress)
).encode()
await conn.writeLp(pb.buffer)
proc sendResponseOk(conn: Connection, ma: MultiAddress) {.async.} =
let pb = AutonatDialResponse(
status: ResponseStatus.Ok, text: Opt.some("Ok"), ma: Opt.some(ma)
).encode()
status: ResponseStatus.Ok,
text: some("Ok"),
ma: some(ma)
).encode()
await conn.writeLp(pb.buffer)
proc tryDial(autonat: Autonat, conn: Connection, addrs: seq[MultiAddress]) {.async.} =
@@ -64,21 +63,15 @@ proc tryDial(autonat: Autonat, conn: Connection, addrs: seq[MultiAddress]) {.asy
var futs: seq[Future[Opt[MultiAddress]]]
try:
# This is to bypass the per peer max connections limit
let outgoingConnection =
autonat.switch.connManager.expectConnection(conn.peerId, Out)
if outgoingConnection.failed() and
outgoingConnection.error of AlreadyExpectingConnectionError:
await conn.sendResponseError(DialRefused, outgoingConnection.error.msg)
return
let outgoingConnection = autonat.switch.connManager.expectConnection(conn.peerId)
# Safer to always try to cancel, because we aren't sure whether the connection was established
defer:
outgoingConnection.cancel()
defer: outgoingConnection.cancel()
# tryDial is to bypass the global max connections limit
futs = addrs.mapIt(autonat.switch.dialer.tryDial(conn.peerId, @[it]))
let fut = await anyCompleted(futs).wait(autonat.dialTimeout)
let ma = await fut
ma.withValue(maddr):
await conn.sendResponseOk(maddr)
if ma.isSome:
await conn.sendResponseOk(ma.get())
else:
await conn.sendResponseError(DialError, "Missing observed address")
except CancelledError as exc:
@@ -99,45 +92,42 @@ proc tryDial(autonat: Autonat, conn: Connection, addrs: seq[MultiAddress]) {.asy
f.cancel()
proc handleDial(autonat: Autonat, conn: Connection, msg: AutonatMsg): Future[void] =
let dial = msg.dial.valueOr:
return conn.sendResponseError(BadRequest, "Missing Dial")
let peerInfo = dial.peerInfo.valueOr:
if msg.dial.isNone() or msg.dial.get().peerInfo.isNone():
return conn.sendResponseError(BadRequest, "Missing Peer Info")
peerInfo.id.withValue(id):
if id != conn.peerId:
return conn.sendResponseError(BadRequest, "PeerId mismatch")
let peerInfo = msg.dial.get().peerInfo.get()
if peerInfo.id.isSome() and peerInfo.id.get() != conn.peerId:
return conn.sendResponseError(BadRequest, "PeerId mismatch")
let observedAddr = conn.observedAddr.valueOr:
if conn.observedAddr.isNone:
return conn.sendResponseError(BadRequest, "Missing observed address")
let observedAddr = conn.observedAddr.get()
var isRelayed = observedAddr.contains(multiCodec("p2p-circuit")).valueOr:
return conn.sendResponseError(DialRefused, "Invalid observed address")
if isRelayed:
return
conn.sendResponseError(DialRefused, "Refused to dial a relayed observed address")
let hostIp = observedAddr[0].valueOr:
return conn.sendResponseError(InternalError, "Wrong observed address")
if not IP.match(hostIp):
var isRelayed = observedAddr.contains(multiCodec("p2p-circuit"))
if isRelayed.isErr() or isRelayed.get():
return conn.sendResponseError(DialRefused, "Refused to dial a relayed observed address")
let hostIp = observedAddr[0]
if hostIp.isErr() or not IP.match(hostIp.get()):
trace "wrong observed address", address=observedAddr
return conn.sendResponseError(InternalError, "Expected an IP address")
var addrs = initHashSet[MultiAddress]()
addrs.incl(observedAddr)
trace "addrs received", addrs = peerInfo.addrs
for ma in peerInfo.addrs:
isRelayed = ma.contains(multiCodec("p2p-circuit")).valueOr:
isRelayed = ma.contains(multiCodec("p2p-circuit"))
if isRelayed.isErr() or isRelayed.get():
continue
let maFirst = ma[0].valueOr:
continue
if not DNS_OR_IP.match(maFirst):
let maFirst = ma[0]
if maFirst.isErr() or not DNS_OR_IP.match(maFirst.get()):
continue
try:
addrs.incl(
if maFirst == hostIp:
if maFirst.get() == hostIp.get():
ma
else:
let maEnd = ma[1 ..^ 1].valueOr:
continue
hostIp & maEnd
let maEnd = ma[1..^1]
if maEnd.isErr(): continue
hostIp.get() & maEnd.get()
)
except LPError as exc:
continue
@@ -150,17 +140,14 @@ proc handleDial(autonat: Autonat, conn: Connection, msg: AutonatMsg): Future[voi
trace "trying to dial", addrs = addrsSeq
return autonat.tryDial(conn, addrsSeq)
proc new*(
T: typedesc[Autonat], switch: Switch, semSize: int = 1, dialTimeout = 15.seconds
): T =
let autonat =
T(switch: switch, sem: newAsyncSemaphore(semSize), dialTimeout: dialTimeout)
proc handleStream(conn: Connection, proto: string) {.async.} =
proc new*(T: typedesc[Autonat], switch: Switch, semSize: int = 1, dialTimeout = 15.seconds): T =
let autonat = T(switch: switch, sem: newAsyncSemaphore(semSize), dialTimeout: dialTimeout)
proc handleStream(conn: Connection, proto: string) {.async, gcsafe.} =
try:
let msg = AutonatMsg.decode(await conn.readLp(1024)).valueOr:
let msgOpt = AutonatMsg.decode(await conn.readLp(1024))
if msgOpt.isNone() or msgOpt.get().msgType != MsgType.Dial:
raise newException(AutonatError, "Received malformed message")
if msg.msgType != MsgType.Dial:
raise newException(AutonatError, "Message type should be dial")
let msg = msgOpt.get()
await autonat.handleDial(conn, msg)
except CancelledError as exc:
raise exc
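The address vetting in handleDial boils down to: refuse relayed observed addresses, require an IP host, then rebuild each advertised address with the observed host IP in front so private prefixes can't poison the dial list. A standalone sketch of that rewrite on plain strings:

import std/strutils

let observedHost = "/ip4/1.2.3.4"                 # first part of conn.observedAddr
let advertised = "/ip4/10.0.0.5/tcp/4001"         # private addr claimed by the peer
let parts = advertised.split('/')                 # @["", "ip4", "10.0.0.5", "tcp", "4001"]
let rest = "/" & parts[3 .. ^1].join("/")         # "/tcp/4001"
assert observedHost & rest == "/ip4/1.2.3.4/tcp/4001"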

View File

@@ -7,67 +7,59 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import std/[deques, sequtils]
import std/[options, deques, sequtils]
import chronos, metrics
import ../../../switch
import ../../../wire
import client
from core import NetworkReachability, AutonatUnreachableError
import ../../../utils/heartbeat
import ../../../crypto/crypto
export core.NetworkReachability
logScope:
topics = "libp2p autonatservice"
declarePublicGauge(
libp2p_autonat_reachability_confidence,
"autonat reachability confidence",
labels = ["reachability"],
)
declarePublicGauge(libp2p_autonat_reachability_confidence, "autonat reachability confidence", labels = ["reachability"])
type
AutonatService* = ref object of Service
newConnectedPeerHandler: PeerEventHandler
addressMapper: AddressMapper
scheduleHandle: Future[void]
networkReachability*: NetworkReachability
confidence: Opt[float]
networkReachability: NetworkReachability
confidence: Option[float]
answers: Deque[NetworkReachability]
autonatClient: AutonatClient
statusAndConfidenceHandler: StatusAndConfidenceHandler
rng: ref HmacDrbgContext
scheduleInterval: Opt[Duration]
scheduleInterval: Option[Duration]
askNewConnectedPeers: bool
numPeersToAsk: int
maxQueueSize: int
minConfidence: float
dialTimeout: Duration
enableAddressMapper: bool
StatusAndConfidenceHandler* = proc(
networkReachability: NetworkReachability, confidence: Opt[float]
): Future[void] {.gcsafe, raises: [].}
NetworkReachability* {.pure.} = enum
NotReachable, Reachable, Unknown
StatusAndConfidenceHandler* = proc (networkReachability: NetworkReachability, confidence: Option[float]): Future[void] {.gcsafe, raises: [Defect].}
proc new*(
T: typedesc[AutonatService],
autonatClient: AutonatClient,
rng: ref HmacDrbgContext,
scheduleInterval: Opt[Duration] = Opt.none(Duration),
askNewConnectedPeers = true,
numPeersToAsk: int = 5,
maxQueueSize: int = 10,
minConfidence: float = 0.3,
dialTimeout = 30.seconds,
enableAddressMapper = true,
): T =
T: typedesc[AutonatService],
autonatClient: AutonatClient,
rng: ref HmacDrbgContext,
scheduleInterval: Option[Duration] = none(Duration),
askNewConnectedPeers = true,
numPeersToAsk: int = 5,
maxQueueSize: int = 10,
minConfidence: float = 0.3,
dialTimeout = 30.seconds): T =
return T(
scheduleInterval: scheduleInterval,
networkReachability: Unknown,
confidence: Opt.none(float),
confidence: none(float),
answers: initDeque[NetworkReachability](),
autonatClient: autonatClient,
rng: rng,
@@ -75,9 +67,10 @@ proc new*(
numPeersToAsk: numPeersToAsk,
maxQueueSize: maxQueueSize,
minConfidence: minConfidence,
dialTimeout: dialTimeout,
enableAddressMapper: enableAddressMapper,
)
dialTimeout: dialTimeout)
proc networkReachability*(self: AutonatService): NetworkReachability {.inline.} =
return self.networkReachability
proc callHandler(self: AutonatService) {.async.} =
if not isNil(self.statusAndConfidenceHandler):
@@ -87,56 +80,32 @@ proc hasEnoughIncomingSlots(switch: Switch): bool =
# We leave some margin instead of comparing to 0, as a peer could connect to us while we are asking for the dial back.
return switch.connManager.slotsAvailable(In) >= 2
proc doesPeerHaveIncomingConn(switch: Switch, peerId: PeerId): bool =
return switch.connManager.selectMuxer(peerId, In) != nil
proc handleAnswer(self: AutonatService, ans: NetworkReachability) {.async.} =
proc handleAnswer(
self: AutonatService, ans: NetworkReachability
): Future[bool] {.async.} =
if ans == Unknown:
return
let oldNetworkReachability = self.networkReachability
let oldConfidence = self.confidence
if self.answers.len == self.maxQueueSize:
self.answers.popFirst()
self.answers.addLast(ans)
self.networkReachability = Unknown
self.confidence = Opt.none(float)
self.confidence = none(float)
const reachabilityPriority = [Reachable, NotReachable]
for reachability in reachabilityPriority:
let confidence = self.answers.countIt(it == reachability) / self.maxQueueSize
libp2p_autonat_reachability_confidence.set(
value = confidence, labelValues = [$reachability]
)
libp2p_autonat_reachability_confidence.set(value = confidence, labelValues = [$reachability])
if self.confidence.isNone and confidence >= self.minConfidence:
self.networkReachability = reachability
self.confidence = Opt.some(confidence)
self.confidence = some(confidence)
debug "Current status",
currentStats = $self.networkReachability,
confidence = $self.confidence,
answers = self.answers
debug "Current status", currentStats = $self.networkReachability, confidence = $self.confidence, answers = self.answers
# Return whether anything has changed
return
self.networkReachability != oldNetworkReachability or
self.confidence != oldConfidence
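For intuition, the confidence computed above is a windowed ratio over the last maxQueueSize answers: with maxQueueSize = 10 and minConfidence = 0.3, three matching answers out of ten are enough to adopt a status. A self-contained sketch of just that rule (not part of this diff):

```nim
import std/[deques, sequtils]

type NetworkReachability = enum
  NotReachable, Reachable, Unknown

# The last ten answers, oldest first; three of them say Reachable.
var answers = [Reachable, NotReachable, Reachable, Unknown, NotReachable,
               Reachable, Unknown, NotReachable, NotReachable, NotReachable].toDeque
const
  maxQueueSize = 10
  minConfidence = 0.3
let confidence = answers.countIt(it == Reachable) / maxQueueSize
assert confidence >= minConfidence # 0.3, so the status flips to Reachable
```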
proc askPeer(
self: AutonatService, switch: Switch, peerId: PeerId
): Future[NetworkReachability] {.async.} =
proc askPeer(self: AutonatService, switch: Switch, peerId: PeerId): Future[NetworkReachability] {.async.} =
logScope:
peerId = $peerId
if doesPeerHaveIncomingConn(switch, peerId):
return Unknown
if not hasEnoughIncomingSlots(switch):
debug "No incoming slots available, not asking peer",
incomingSlotsAvailable = switch.connManager.slotsAvailable(In)
debug "No incoming slots available, not asking peer", incomingSlotsAvailable=switch.connManager.slotsAvailable(In)
return Unknown
trace "Asking peer for reachability"
@@ -154,10 +123,9 @@ proc askPeer(
except CatchableError as error:
debug "dialMe unexpected error", msg = error.msg
Unknown
let hasReachabilityOrConfidenceChanged = await self.handleAnswer(ans)
if hasReachabilityOrConfidenceChanged:
await self.callHandler()
await switch.peerInfo.update()
await self.handleAnswer(ans)
if not isNil(self.statusAndConfidenceHandler):
await self.statusAndConfidenceHandler(self.networkReachability, self.confidence)
return ans
proc askConnectedPeers(self: AutonatService, switch: Switch) {.async.} =
@@ -169,8 +137,7 @@ proc askConnectedPeers(self: AutonatService, switch: Switch) {.async.} =
if answersFromPeers >= self.numPeersToAsk:
break
if not hasEnoughIncomingSlots(switch):
debug "No incoming slots available, not asking peers",
incomingSlotsAvailable = switch.connManager.slotsAvailable(In)
debug "No incoming slots available, not asking peers", incomingSlotsAvailable=switch.connManager.slotsAvailable(In)
break
if (await askPeer(self, switch, peer)) != Unknown:
answersFromPeers.inc()
@@ -179,51 +146,26 @@ proc schedule(service: AutonatService, switch: Switch, interval: Duration) {.asy
heartbeat "Scheduling AutonatService run", interval:
await service.run(switch)
proc addressMapper(
self: AutonatService, peerStore: PeerStore, listenAddrs: seq[MultiAddress]
): Future[seq[MultiAddress]] {.async.} =
if self.networkReachability != NetworkReachability.Reachable:
return listenAddrs
var addrs = newSeq[MultiAddress]()
for listenAddr in listenAddrs:
var processedMA = listenAddr
try:
if not listenAddr.isPublicMA() and
self.networkReachability == NetworkReachability.Reachable:
processedMA = peerStore.guessDialableAddr(listenAddr)
# handle manual port forwarding
except CatchableError as exc:
debug "Error while handling address mapper", msg = exc.msg
addrs.add(processedMA)
return addrs
method setup*(self: AutonatService, switch: Switch): Future[bool] {.async.} =
self.addressMapper = proc(
listenAddrs: seq[MultiAddress]
): Future[seq[MultiAddress]] {.async.} =
return await addressMapper(self, switch.peerStore, listenAddrs)
info "Setting up AutonatService"
let hasBeenSetup = await procCall Service(self).setup(switch)
if hasBeenSetup:
if self.askNewConnectedPeers:
self.newConnectedPeerHandler = proc(
peerId: PeerId, event: PeerEvent
): Future[void] {.async.} =
self.newConnectedPeerHandler = proc (peerId: PeerId, event: PeerEvent): Future[void] {.async.} =
if switch.connManager.selectConn(peerId, In) != nil: # no need to ask an incoming peer
return
discard askPeer(self, switch, peerId)
switch.connManager.addPeerEventHandler(
self.newConnectedPeerHandler, PeerEventKind.Joined
)
self.scheduleInterval.withValue(interval):
self.scheduleHandle = schedule(self, switch, interval)
if self.enableAddressMapper:
switch.peerInfo.addressMappers.add(self.addressMapper)
await self.callHandler()
switch.connManager.addPeerEventHandler(self.newConnectedPeerHandler, PeerEventKind.Joined)
if self.scheduleInterval.isSome():
self.scheduleHandle = schedule(self, switch, self.scheduleInterval.get())
return hasBeenSetup
method run*(self: AutonatService, switch: Switch) {.async, public.} =
trace "Running AutonatService"
await askConnectedPeers(self, switch)
await self.callHandler()
method stop*(self: AutonatService, switch: Switch): Future[bool] {.async, public.} =
info "Stopping AutonatService"
@@ -233,15 +175,8 @@ method stop*(self: AutonatService, switch: Switch): Future[bool] {.async, public
self.scheduleHandle.cancel()
self.scheduleHandle = nil
if not isNil(self.newConnectedPeerHandler):
switch.connManager.removePeerEventHandler(
self.newConnectedPeerHandler, PeerEventKind.Joined
)
if self.enableAddressMapper:
switch.peerInfo.addressMappers.keepItIf(it != self.addressMapper)
await switch.peerInfo.update()
switch.connManager.removePeerEventHandler(self.newConnectedPeerHandler, PeerEventKind.Joined)
return hasBeenStopped
proc statusAndConfidenceHandler*(
self: AutonatService, statusAndConfidenceHandler: StatusAndConfidenceHandler
) =
proc statusAndConfidenceHandler*(self: AutonatService, statusAndConfidenceHandler: StatusAndConfidenceHandler) =
self.statusAndConfidenceHandler = statusAndConfidenceHandler
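Taken together, wiring this service into a switch looks roughly as follows. This is a hedged sketch using the Opt-based signatures from one side of this diff; the builder calls and module paths are assumptions, not part of the change:

```nim
import chronos, chronicles
import libp2p
import libp2p/protocols/connectivity/autonat/[client, service]

proc main() {.async.} =
  let rng = newRng()
  let autonat = AutonatService.new(
    AutonatClient.new(), rng, scheduleInterval = Opt.some(30.seconds))
  autonat.statusAndConfidenceHandler(
    proc(reachability: NetworkReachability,
         confidence: Opt[float]): Future[void] {.gcsafe, async.} =
      info "reachability update", reachability = $reachability)
  let switch = SwitchBuilder.new()
    .withRng(rng)
    .withAddress(MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet())
    .withTcpTransport()
    .withMplex()
    .withNoise()
    .withServices(@[Service(autonat)]) # the service drives run()/setup() above
    .build()
  await switch.start()

waitFor main()
```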

@@ -1,114 +0,0 @@
# Nim-LibP2P
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
import std/sequtils
import stew/results
import chronos, chronicles
import core
import
../../protocol, ../../../stream/connection, ../../../switch, ../../../utils/future
export DcutrError
type DcutrClient* = ref object
connectTimeout: Duration
maxDialableAddrs: int
logScope:
topics = "libp2p dcutrclient"
proc new*(
T: typedesc[DcutrClient], connectTimeout = 15.seconds, maxDialableAddrs = 8
): T =
return T(connectTimeout: connectTimeout, maxDialableAddrs: maxDialableAddrs)
proc startSync*(
self: DcutrClient, switch: Switch, remotePeerId: PeerId, addrs: seq[MultiAddress]
) {.async.} =
logScope:
peerId = switch.peerInfo.peerId
var
peerDialableAddrs: seq[MultiAddress]
stream: Connection
try:
var ourDialableAddrs = getHolePunchableAddrs(addrs)
if ourDialableAddrs.len == 0:
debug "Dcutr initiator has no supported dialable addresses. Aborting Dcutr.",
addrs
return
stream = await switch.dial(remotePeerId, DcutrCodec)
await stream.send(MsgType.Connect, addrs)
debug "Dcutr initiator has sent a Connect message."
let rttStart = Moment.now()
let connectAnswer = DcutrMsg.decode(await stream.readLp(1024))
peerDialableAddrs = getHolePunchableAddrs(connectAnswer.addrs)
if peerDialableAddrs.len == 0:
debug "Dcutr receiver has no supported dialable addresses to connect to. Aborting Dcutr.",
addrs = connectAnswer.addrs
return
let rttEnd = Moment.now()
debug "Dcutr initiator has received a Connect message back.", connectAnswer
let halfRtt = (rttEnd - rttStart) div 2'i64
await stream.send(MsgType.Sync, @[])
debug "Dcutr initiator has sent a Sync message."
await sleepAsync(halfRtt)
if peerDialableAddrs.len > self.maxDialableAddrs:
peerDialableAddrs = peerDialableAddrs[0 ..< self.maxDialableAddrs]
var futs = peerDialableAddrs.mapIt(
switch.connect(
stream.peerId,
@[it],
forceDial = true,
reuseConnection = false,
dir = Direction.In,
)
)
try:
discard await anyCompleted(futs).wait(self.connectTimeout)
debug "Dcutr initiator has directly connected to the remote peer."
finally:
for fut in futs:
fut.cancel()
except CancelledError as err:
raise err
except AllFuturesFailedError as err:
debug "Dcutr initiator could not connect to the remote peer, all connect attempts failed",
peerDialableAddrs, msg = err.msg
raise newException(
DcutrError,
"Dcutr initiator could not connect to the remote peer, all connect attempts failed",
err,
)
except AsyncTimeoutError as err:
debug "Dcutr initiator could not connect to the remote peer, all connect attempts timed out",
peerDialableAddrs, msg = err.msg
raise newException(
DcutrError,
"Dcutr initiator could not connect to the remote peer, all connect attempts timed out",
err,
)
except CatchableError as err:
debug "Unexpected error when Dcutr initiator tried to connect to the remote peer",
err = err.msg
raise newException(
DcutrError,
"Unexpected error when Dcutr initiator tried to connect to the remote peer", err,
)
finally:
if stream != nil:
await stream.close()
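For orientation, a hypothetical call site for startSync; it assumes switch is a started Switch behind a NAT and remotePeerId is a peer currently reached over a relay (both names are placeholders):

```nim
let client = DcutrClient.new(connectTimeout = 15.seconds, maxDialableAddrs = 8)
try:
  await client.startSync(switch, remotePeerId, switch.peerInfo.addrs)
  # On success the switch now also holds a direct, hole-punched connection.
except DcutrError as err:
  debug "hole punch failed, keeping the relayed connection", msg = err.msg
```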

@@ -1,64 +0,0 @@
# Nim-LibP2P
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
import std/sequtils
import chronos
import stew/objects
import ../../../multiaddress, ../../../errors, ../../../stream/connection
export multiaddress
const DcutrCodec* = "/libp2p/dcutr"
type
MsgType* = enum
Connect = 100
Sync = 300
DcutrMsg* = object
msgType*: MsgType
addrs*: seq[MultiAddress]
DcutrError* = object of LPError
proc encode*(msg: DcutrMsg): ProtoBuffer =
result = initProtoBuffer()
result.write(1, msg.msgType.uint)
for addr in msg.addrs:
result.write(2, addr)
result.finish()
proc decode*(_: typedesc[DcutrMsg], buf: seq[byte]): DcutrMsg {.raises: [DcutrError].} =
var
msgTypeOrd: uint32
dcutrMsg: DcutrMsg
var pb = initProtoBuffer(buf)
var r1 = pb.getField(1, msgTypeOrd)
let r2 = pb.getRepeatedField(2, dcutrMsg.addrs)
if r1.isErr or r2.isErr or not checkedEnumAssign(dcutrMsg.msgType, msgTypeOrd):
raise newException(DcutrError, "Received malformed message")
return dcutrMsg
proc send*(conn: Connection, msgType: MsgType, addrs: seq[MultiAddress]) {.async.} =
let pb = DcutrMsg(msgType: msgType, addrs: addrs).encode()
await conn.writeLp(pb.buffer)
proc getHolePunchableAddrs*(
addrs: seq[MultiAddress]
): seq[MultiAddress] {.raises: [LPError].} =
var result = newSeq[MultiAddress]()
for a in addrs:
# This is necessary to also accept addrs like /ip4/198.51.100/tcp/1234/p2p/QmYyQSo1c1Ym7orWxLYvCrM2EmxFTANf8wXmmE7DWjhx5N
if [TCP, mapAnd(TCP_DNS, P2PPattern), mapAnd(TCP_IP, P2PPattern)].anyIt(it.match(a)):
result.add(a[0 .. 1].tryGet())
return result
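A quick round trip through the codec above (the address is an illustrative placeholder):

```nim
let msg = DcutrMsg(
  msgType: MsgType.Connect,
  addrs: @[MultiAddress.init("/ip4/192.0.2.7/tcp/4001").tryGet()])
let decoded = DcutrMsg.decode(msg.encode().buffer)
assert decoded.msgType == MsgType.Connect and decoded.addrs == msg.addrs
```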

@@ -1,94 +0,0 @@
# Nim-LibP2P
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
import std/[sets, sequtils]
import stew/[results, objects]
import chronos, chronicles
import core
import
../../protocol, ../../../stream/connection, ../../../switch, ../../../utils/future
export chronicles
type Dcutr* = ref object of LPProtocol
logScope:
topics = "libp2p dcutr"
proc new*(
T: typedesc[Dcutr],
switch: Switch,
connectTimeout = 15.seconds,
maxDialableAddrs = 8,
): T =
proc handleStream(stream: Connection, proto: string) {.async.} =
var peerDialableAddrs: seq[MultiAddress]
try:
let connectMsg = DcutrMsg.decode(await stream.readLp(1024))
debug "Dcutr receiver received a Connect message.", connectMsg
var ourAddrs = switch.peerStore.getMostObservedProtosAndPorts()
# likely empty when the peer is reachable
if ourAddrs.len == 0:
# this list should be the same as the peer's public addrs when it is reachable
ourAddrs =
switch.peerInfo.listenAddrs.mapIt(switch.peerStore.guessDialableAddr(it))
var ourDialableAddrs = getHolePunchableAddrs(ourAddrs)
if ourDialableAddrs.len == 0:
debug "Dcutr receiver has no supported dialable addresses. Aborting Dcutr.",
ourAddrs
return
await stream.send(MsgType.Connect, ourAddrs)
debug "Dcutr receiver has sent a Connect message back."
let syncMsg = DcutrMsg.decode(await stream.readLp(1024))
debug "Dcutr receiver has received a Sync message.", syncMsg
peerDialableAddrs = getHolePunchableAddrs(connectMsg.addrs)
if peerDialableAddrs.len == 0:
debug "Dcutr initiator has no supported dialable addresses to connect to. Aborting Dcutr.",
addrs = connectMsg.addrs
return
if peerDialableAddrs.len > maxDialableAddrs:
peerDialableAddrs = peerDialableAddrs[0 ..< maxDialableAddrs]
var futs = peerDialableAddrs.mapIt(
switch.connect(
stream.peerId,
@[it],
forceDial = true,
reuseConnection = false,
dir = Direction.Out,
)
)
try:
discard await anyCompleted(futs).wait(connectTimeout)
debug "Dcutr receiver has directly connected to the remote peer."
finally:
for fut in futs:
fut.cancel()
except CancelledError as err:
raise err
except AllFuturesFailedError as err:
debug "Dcutr receiver could not connect to the remote peer, " &
"all connect attempts failed", peerDialableAddrs, msg = err.msg
except AsyncTimeoutError as err:
debug "Dcutr receiver could not connect to the remote peer, " &
"all connect attempts timed out", peerDialableAddrs, msg = err.msg
except CatchableError as err:
warn "Unexpected error when Dcutr receiver tried to connect " &
"to the remote peer", msg = err.msg
let self = T()
self.handler = handleStream
self.codec = DcutrCodec
self
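Mounting the receiver side is then a one-liner on an existing switch (sketch; switch is assumed to be already built):

```nim
let dcutr = Dcutr.new(switch)
switch.mount(dcutr) # reachable peers now answer Connect/Sync from NATed dialers
```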

@@ -7,19 +7,25 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import times, options
import times
import chronos, chronicles
import
./relay,
./messages,
./rconn,
./utils,
../../../peerinfo,
../../../switch,
../../../multiaddress,
../../../stream/connection
import ./relay,
./messages,
./rconn,
./utils,
../../../peerinfo,
../../../switch,
../../../multiaddress,
../../../stream/connection
export options
logScope:
topics = "libp2p relay relay-client"
@@ -31,37 +37,38 @@ type
ReservationError* = object of RelayClientError
RelayV1DialError* = object of RelayClientError
RelayV2DialError* = object of RelayClientError
RelayClientAddConn* = proc(
conn: Connection, duration: uint32, data: uint64
): Future[void] {.gcsafe, raises: [].}
RelayClientAddConn* = proc(conn: Connection,
duration: uint32,
data: uint64): Future[void] {.gcsafe, raises: [Defect].}
RelayClient* = ref object of Relay
onNewConnection*: RelayClientAddConn
canHop: bool
Rsvp* = object
expire*: uint64 # required, Unix expiration time (UTC)
expire*: uint64 # required, Unix expiration time (UTC)
addrs*: seq[MultiAddress] # relay address for reserving peer
voucher*: Opt[Voucher] # optional, reservation voucher
limitDuration*: uint32 # seconds
limitData*: uint64 # bytes
voucher*: Option[Voucher] # optional, reservation voucher
limitDuration*: uint32 # seconds
limitData*: uint64 # bytes
proc sendStopError(conn: Connection, code: StatusV2) {.async.} =
trace "send stop status", status = $code & " (" & $ord(code) & ")"
let msg = StopMessage(msgType: StopMessageType.Status, status: Opt.some(code))
let msg = StopMessage(msgType: StopMessageType.Status, status: some(code))
await conn.writeLp(encode(msg).buffer)
proc handleRelayedConnect(
cl: RelayClient, conn: Connection, msg: StopMessage
) {.async.} =
proc handleRelayedConnect(cl: RelayClient, conn: Connection, msg: StopMessage) {.async.} =
if msg.peer.isNone():
await sendStopError(conn, MalformedMessage)
return
let
# TODO: check the go version to see in which way this could fail
# it's unclear in the spec
src = msg.peer.valueOr:
await sendStopError(conn, MalformedMessage)
return
src = msg.peer.get()
limitDuration = msg.limit.duration
limitData = msg.limit.data
msg = StopMessage(msgType: StopMessageType.Status, status: Opt.some(Ok))
msg = StopMessage(
msgType: StopMessageType.Status,
status: some(Ok))
pb = encode(msg)
trace "incoming relay connection", src
@@ -73,156 +80,146 @@ proc handleRelayedConnect(
await conn.writeLp(pb.buffer)
# This sounds redundant, but the callback could, in theory, be set to nil during
# conn.writeLp so it's safer to double check
if cl.onNewConnection != nil:
await cl.onNewConnection(conn, limitDuration, limitData)
else:
await conn.close()
if cl.onNewConnection != nil: await cl.onNewConnection(conn, limitDuration, limitData)
else: await conn.close()
proc reserve*(
cl: RelayClient, peerId: PeerId, addrs: seq[MultiAddress] = @[]
): Future[Rsvp] {.async.} =
proc reserve*(cl: RelayClient,
peerId: PeerId,
addrs: seq[MultiAddress] = @[]): Future[Rsvp] {.async.} =
let conn = await cl.switch.dial(peerId, addrs, RelayV2HopCodec)
defer:
await conn.close()
defer: await conn.close()
let
pb = encode(HopMessage(msgType: HopMessageType.Reserve))
msg =
try:
await conn.writeLp(pb.buffer)
HopMessage.decode(await conn.readLp(RelayClientMsgSize)).tryGet()
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "error writing or reading reservation message", exc = exc.msg
raise newException(ReservationError, exc.msg)
msg = try:
await conn.writeLp(pb.buffer)
HopMessage.decode(await conn.readLp(RelayClientMsgSize)).get()
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "error writing or reading reservation message", exc=exc.msg
raise newException(ReservationError, exc.msg)
if msg.msgType != HopMessageType.Status:
raise newException(ReservationError, "Unexpected relay response type")
if msg.status.get(UnexpectedMessage) != Ok:
raise newException(ReservationError, "Reservation failed")
let reservation = msg.reservation.valueOr:
if msg.reservation.isNone():
raise newException(ReservationError, "Missing reservation information")
let reservation = msg.reservation.get()
if reservation.expire > int64.high().uint64 or
now().utc > reservation.expire.int64.fromUnix.utc:
now().utc > reservation.expire.int64.fromUnix.utc:
raise newException(ReservationError, "Bad expiration date")
result.expire = reservation.expire
result.addrs = reservation.addrs
reservation.svoucher.withValue(sv):
let svoucher = SignedVoucher.decode(sv).valueOr:
if reservation.svoucher.isSome():
let svoucher = SignedVoucher.decode(reservation.svoucher.get())
if svoucher.isErr() or svoucher.get().data.relayPeerId != peerId:
raise newException(ReservationError, "Invalid voucher")
if svoucher.data.relayPeerId != peerId:
raise newException(ReservationError, "Invalid voucher PeerId")
result.voucher = Opt.some(svoucher.data)
result.voucher = some(svoucher.get().data)
result.limitDuration = msg.limit.duration
result.limitData = msg.limit.data
proc dialPeerV1*(
cl: RelayClient, conn: Connection, dstPeerId: PeerId, dstAddrs: seq[MultiAddress]
): Future[Connection] {.async.} =
cl: RelayClient,
conn: Connection,
dstPeerId: PeerId,
dstAddrs: seq[MultiAddress]): Future[Connection] {.async.} =
var
msg = RelayMessage(
msgType: Opt.some(RelayType.Hop),
srcPeer: Opt.some(
RelayPeer(peerId: cl.switch.peerInfo.peerId, addrs: cl.switch.peerInfo.addrs)
),
dstPeer: Opt.some(RelayPeer(peerId: dstPeerId, addrs: dstAddrs)),
)
msgType: some(RelayType.Hop),
srcPeer: some(RelayPeer(peerId: cl.switch.peerInfo.peerId, addrs: cl.switch.peerInfo.addrs)),
dstPeer: some(RelayPeer(peerId: dstPeerId, addrs: dstAddrs)))
pb = encode(msg)
trace "Dial peer", msgSend = msg
trace "Dial peer", msgSend=msg
try:
await conn.writeLp(pb.buffer)
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "error writing hop request", exc = exc.msg
trace "error writing hop request", exc=exc.msg
raise exc
let msgRcvFromRelayOpt =
try:
RelayMessage.decode(await conn.readLp(RelayClientMsgSize))
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "error reading stop response", exc = exc.msg
await sendStatus(conn, StatusV1.HopCantOpenDstStream)
raise exc
let msgRcvFromRelayOpt = try:
RelayMessage.decode(await conn.readLp(RelayClientMsgSize))
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "error reading stop response", exc=exc.msg
await sendStatus(conn, StatusV1.HopCantOpenDstStream)
raise exc
try:
let msgRcvFromRelay = msgRcvFromRelayOpt.valueOr:
if msgRcvFromRelayOpt.isNone:
raise newException(RelayV1DialError, "Hop can't open destination stream")
if msgRcvFromRelay.msgType.tryGet() != RelayType.Status:
raise newException(
RelayV1DialError, "Hop can't open destination stream: wrong message type"
)
if msgRcvFromRelay.status.tryGet() != StatusV1.Success:
raise newException(
RelayV1DialError, "Hop can't open destination stream: status failed"
)
let msgRcvFromRelay = msgRcvFromRelayOpt.get()
if msgRcvFromRelay.msgType.isNone or msgRcvFromRelay.msgType.get() != RelayType.Status:
raise newException(RelayV1DialError, "Hop can't open destination stream: wrong message type")
if msgRcvFromRelay.status.isNone or msgRcvFromRelay.status.get() != StatusV1.Success:
raise newException(RelayV1DialError, "Hop can't open destination stream: status failed")
except RelayV1DialError as exc:
await sendStatus(conn, StatusV1.HopCantOpenDstStream)
raise exc
except ValueError as exc:
await sendStatus(conn, StatusV1.HopCantOpenDstStream)
raise newException(RelayV1DialError, exc.msg)
result = conn
proc dialPeerV2*(
cl: RelayClient,
conn: RelayConnection,
dstPeerId: PeerId,
dstAddrs: seq[MultiAddress],
): Future[Connection] {.async.} =
dstAddrs: seq[MultiAddress]): Future[Connection] {.async.} =
let
p = Peer(peerId: dstPeerId, addrs: dstAddrs)
pb = encode(HopMessage(msgType: HopMessageType.Connect, peer: Opt.some(p)))
pb = encode(HopMessage(msgType: HopMessageType.Connect, peer: some(p)))
trace "Dial peer", p
let msgRcvFromRelay =
try:
await conn.writeLp(pb.buffer)
HopMessage.decode(await conn.readLp(RelayClientMsgSize)).tryGet()
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "error reading stop response", exc = exc.msg
raise newException(RelayV2DialError, exc.msg)
let msgRcvFromRelay = try:
await conn.writeLp(pb.buffer)
HopMessage.decode(await conn.readLp(RelayClientMsgSize)).get()
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "error reading stop response", exc=exc.msg
raise newException(RelayV2DialError, exc.msg)
if msgRcvFromRelay.msgType != HopMessageType.Status:
raise newException(RelayV2DialError, "Unexpected stop response")
if msgRcvFromRelay.status.get(UnexpectedMessage) != Ok:
trace "Relay stop failed", msg = msgRcvFromRelay.status
trace "Relay stop failed", msg = msgRcvFromRelay.status.get()
raise newException(RelayV2DialError, "Relay stop failure")
conn.limitDuration = msgRcvFromRelay.limit.duration
conn.limitData = msgRcvFromRelay.limit.data
return conn
proc handleStopStreamV2(cl: RelayClient, conn: Connection) {.async.} =
let msg = StopMessage.decode(await conn.readLp(RelayClientMsgSize)).valueOr:
proc handleStopStreamV2(cl: RelayClient, conn: Connection) {.async, gcsafe.} =
let msgOpt = StopMessage.decode(await conn.readLp(RelayClientMsgSize))
if msgOpt.isNone():
await sendHopStatus(conn, MalformedMessage)
return
trace "client circuit relay v2 handle stream", msg
trace "client circuit relay v2 handle stream", msg = msgOpt.get()
let msg = msgOpt.get()
if msg.msgType == StopMessageType.Connect:
await cl.handleRelayedConnect(conn, msg)
else:
trace "Unexpected client / relayv2 handshake", msgType = msg.msgType
trace "Unexpected client / relayv2 handshake", msgType=msg.msgType
await sendStopError(conn, MalformedMessage)
proc handleStop(cl: RelayClient, conn: Connection, msg: RelayMessage) {.async.} =
let src = msg.srcPeer.valueOr:
proc handleStop(cl: RelayClient, conn: Connection, msg: RelayMessage) {.async, gcsafe.} =
if msg.srcPeer.isNone:
await sendStatus(conn, StatusV1.StopSrcMultiaddrInvalid)
return
let src = msg.srcPeer.get()
let dst = msg.dstPeer.valueOr:
if msg.dstPeer.isNone:
await sendStatus(conn, StatusV1.StopDstMultiaddrInvalid)
return
let dst = msg.dstPeer.get()
if dst.peerId != cl.switch.peerInfo.peerId:
await sendStatus(conn, StatusV1.StopDstMultiaddrInvalid)
return
@@ -236,69 +233,52 @@ proc handleStop(cl: RelayClient, conn: Connection, msg: RelayMessage) {.async.}
await sendStatus(conn, StatusV1.Success)
# This sounds redundant, but the callback could, in theory, be set to nil during
# sendStatus(Success) so it's safer to double check
if cl.onNewConnection != nil:
await cl.onNewConnection(conn, 0, 0)
else:
await conn.close()
if cl.onNewConnection != nil: await cl.onNewConnection(conn, 0, 0)
else: await conn.close()
proc handleStreamV1(cl: RelayClient, conn: Connection) {.async.} =
let msg = RelayMessage.decode(await conn.readLp(RelayClientMsgSize)).valueOr:
proc handleStreamV1(cl: RelayClient, conn: Connection) {.async, gcsafe.} =
let msgOpt = RelayMessage.decode(await conn.readLp(RelayClientMsgSize))
if msgOpt.isNone:
await sendStatus(conn, StatusV1.MalformedMessage)
return
trace "client circuit relay v1 handle stream", msg
let typ = msg.msgType.valueOr:
trace "Message type not set"
await sendStatus(conn, StatusV1.MalformedMessage)
return
case typ
of RelayType.Hop:
if cl.canHop:
await cl.handleHop(conn, msg)
trace "client circuit relay v1 handle stream", msg = msgOpt.get()
let msg = msgOpt.get()
case msg.msgType.get:
of RelayType.Hop:
if cl.canHop: await cl.handleHop(conn, msg)
else: await sendStatus(conn, StatusV1.HopCantSpeakRelay)
of RelayType.Stop: await cl.handleStop(conn, msg)
of RelayType.CanHop:
if cl.canHop: await sendStatus(conn, StatusV1.Success)
else: await sendStatus(conn, StatusV1.HopCantSpeakRelay)
else:
await sendStatus(conn, StatusV1.HopCantSpeakRelay)
of RelayType.Stop:
await cl.handleStop(conn, msg)
of RelayType.CanHop:
if cl.canHop:
await sendStatus(conn, StatusV1.Success)
else:
await sendStatus(conn, StatusV1.HopCantSpeakRelay)
else:
trace "Unexpected relay handshake", msgType = msg.msgType
await sendStatus(conn, StatusV1.MalformedMessage)
trace "Unexpected relay handshake", msgType=msg.msgType
await sendStatus(conn, StatusV1.MalformedMessage)
proc new*(
T: typedesc[RelayClient],
canHop: bool = false,
reservationTTL: times.Duration = DefaultReservationTTL,
limitDuration: uint32 = DefaultLimitDuration,
limitData: uint64 = DefaultLimitData,
heartbeatSleepTime: uint32 = DefaultHeartbeatSleepTime,
maxCircuit: int = MaxCircuit,
maxCircuitPerPeer: int = MaxCircuitPerPeer,
msgSize: int = RelayClientMsgSize,
circuitRelayV1: bool = false,
): T =
let cl = T(
canHop: canHop,
reservationTTL: reservationTTL,
limit: Limit(duration: limitDuration, data: limitData),
heartbeatSleepTime: heartbeatSleepTime,
maxCircuit: maxCircuit,
maxCircuitPerPeer: maxCircuitPerPeer,
msgSize: msgSize,
isCircuitRelayV1: circuitRelayV1,
)
proc handleStream(conn: Connection, proto: string) {.async.} =
proc new*(T: typedesc[RelayClient], canHop: bool = false,
reservationTTL: times.Duration = DefaultReservationTTL,
limitDuration: uint32 = DefaultLimitDuration,
limitData: uint64 = DefaultLimitData,
heartbeatSleepTime: uint32 = DefaultHeartbeatSleepTime,
maxCircuit: int = MaxCircuit,
maxCircuitPerPeer: int = MaxCircuitPerPeer,
msgSize: int = RelayClientMsgSize,
circuitRelayV1: bool = false): T =
let cl = T(canHop: canHop,
reservationTTL: reservationTTL,
limit: Limit(duration: limitDuration, data: limitData),
heartbeatSleepTime: heartbeatSleepTime,
maxCircuit: maxCircuit,
maxCircuitPerPeer: maxCircuitPerPeer,
msgSize: msgSize,
isCircuitRelayV1: circuitRelayV1)
proc handleStream(conn: Connection, proto: string) {.async, gcsafe.} =
try:
case proto
of RelayV1Codec:
await cl.handleStreamV1(conn)
of RelayV2StopCodec:
await cl.handleStopStreamV2(conn)
of RelayV2HopCodec:
await cl.handleHopStreamV2(conn)
case proto:
of RelayV1Codec: await cl.handleStreamV1(conn)
of RelayV2StopCodec: await cl.handleStopStreamV2(conn)
of RelayV2HopCodec: await cl.handleHopStreamV2(conn)
except CancelledError as exc:
raise exc
except CatchableError as exc:
@@ -308,9 +288,8 @@ proc new*(
await conn.close()
cl.handler = handleStream
cl.codecs =
if cl.canHop:
@[RelayV1Codec, RelayV2HopCodec, RelayV2StopCodec]
else:
@[RelayV1Codec, RelayV2StopCodec]
cl.codecs = if cl.canHop:
@[RelayV1Codec, RelayV2HopCodec, RelayV2StopCodec]
else:
@[RelayV1Codec, RelayV2StopCodec]
cl
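As a usage note, the reservation flow above is typically driven like this (sketch; relayPeerId and relayAddrs are placeholders for a known hop relay):

```nim
let rsvp = await relayClient.reserve(relayPeerId, relayAddrs)
info "relay slot reserved",
  expire = rsvp.expire,               # Unix UTC expiration time
  addrs = rsvp.addrs,                 # relayed addresses to advertise
  maxDurationSecs = rsvp.limitDuration,
  maxBytes = rsvp.limitData
```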

@@ -7,11 +7,15 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import macros
import stew/[objects, results]
import ../../../peerinfo, ../../../signed_envelope
import options, macros, sequtils
import stew/objects
import ../../../peerinfo,
../../../signed_envelope
# Circuit Relay V1 Message
@@ -45,36 +49,36 @@ type
addrs*: seq[MultiAddress]
RelayMessage* = object
msgType*: Opt[RelayType]
srcPeer*: Opt[RelayPeer]
dstPeer*: Opt[RelayPeer]
status*: Opt[StatusV1]
msgType*: Option[RelayType]
srcPeer*: Option[RelayPeer]
dstPeer*: Option[RelayPeer]
status*: Option[StatusV1]
proc encode*(msg: RelayMessage): ProtoBuffer =
result = initProtoBuffer()
msg.msgType.withValue(typ):
result.write(1, typ.ord.uint)
msg.srcPeer.withValue(srcPeer):
if isSome(msg.msgType):
result.write(1, msg.msgType.get().ord.uint)
if isSome(msg.srcPeer):
var peer = initProtoBuffer()
peer.write(1, srcPeer.peerId)
for ma in srcPeer.addrs:
peer.write(1, msg.srcPeer.get().peerId)
for ma in msg.srcPeer.get().addrs:
peer.write(2, ma.data.buffer)
peer.finish()
result.write(2, peer.buffer)
msg.dstPeer.withValue(dstPeer):
if isSome(msg.dstPeer):
var peer = initProtoBuffer()
peer.write(1, dstPeer.peerId)
for ma in dstPeer.addrs:
peer.write(1, msg.dstPeer.get().peerId)
for ma in msg.dstPeer.get().addrs:
peer.write(2, ma.data.buffer)
peer.finish()
result.write(3, peer.buffer)
msg.status.withValue(status):
result.write(4, status.ord.uint)
if isSome(msg.status):
result.write(4, msg.status.get().ord.uint)
result.finish()
proc decode*(_: typedesc[RelayMessage], buf: seq[byte]): Opt[RelayMessage] =
proc decode*(_: typedesc[RelayMessage], buf: seq[byte]): Option[RelayMessage] =
var
rMsg: RelayMessage
msgTypeOrd: uint32
@@ -84,44 +88,53 @@ proc decode*(_: typedesc[RelayMessage], buf: seq[byte]): Opt[RelayMessage] =
pbSrc: ProtoBuffer
pbDst: ProtoBuffer
let pb = initProtoBuffer(buf)
let
pb = initProtoBuffer(buf)
r1 = pb.getField(1, msgTypeOrd)
r2 = pb.getField(2, pbSrc)
r3 = pb.getField(3, pbDst)
r4 = pb.getField(4, statusOrd)
if ?pb.getField(1, msgTypeOrd).toOpt():
if r1.isErr() or r2.isErr() or r3.isErr() or r4.isErr():
return none(RelayMessage)
if r2.get() and
(pbSrc.getField(1, src.peerId).isErr() or
pbSrc.getRepeatedField(2, src.addrs).isErr()):
return none(RelayMessage)
if r3.get() and
(pbDst.getField(1, dst.peerId).isErr() or
pbDst.getRepeatedField(2, dst.addrs).isErr()):
return none(RelayMessage)
if r1.get():
if msgTypeOrd.int notin RelayType:
return Opt.none(RelayMessage)
rMsg.msgType = Opt.some(RelayType(msgTypeOrd))
if ?pb.getField(2, pbSrc).toOpt():
discard ?pbSrc.getField(1, src.peerId).toOpt()
discard ?pbSrc.getRepeatedField(2, src.addrs).toOpt()
rMsg.srcPeer = Opt.some(src)
if ?pb.getField(3, pbDst).toOpt():
discard ?pbDst.getField(1, dst.peerId).toOpt()
discard ?pbDst.getRepeatedField(2, dst.addrs).toOpt()
rMsg.dstPeer = Opt.some(dst)
if ?pb.getField(4, statusOrd).toOpt():
var status: StatusV1
if not checkedEnumAssign(status, statusOrd):
return Opt.none(RelayMessage)
rMsg.status = Opt.some(status)
Opt.some(rMsg)
return none(RelayMessage)
rMsg.msgType = some(RelayType(msgTypeOrd))
if r2.get(): rMsg.srcPeer = some(src)
if r3.get(): rMsg.dstPeer = some(dst)
if r4.get():
if statusOrd.int notin StatusV1:
return none(RelayMessage)
rMsg.status = some(StatusV1(statusOrd))
some(rMsg)
# Voucher
type Voucher* = object
relayPeerId*: PeerId # peer ID of the relay
reservingPeerId*: PeerId # peer ID of the reserving peer
expiration*: uint64 # UNIX UTC expiration time for the reservation
type
Voucher* = object
relayPeerId*: PeerId # peer ID of the relay
reservingPeerId*: PeerId # peer ID of the reserving peer
expiration*: uint64 # UNIX UTC expiration time for the reservation
proc decode*(_: typedesc[Voucher], buf: seq[byte]): Result[Voucher, ProtoError] =
let pb = initProtoBuffer(buf)
var v = Voucher()
?pb.getRequiredField(1, v.relayPeerId)
?pb.getRequiredField(2, v.reservingPeerId)
?pb.getRequiredField(3, v.expiration)
? pb.getRequiredField(1, v.relayPeerId)
? pb.getRequiredField(2, v.reservingPeerId)
? pb.getRequiredField(3, v.expiration)
ok(v)
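A round-trip check for the voucher codec (sketch; the two PeerIds are placeholders):

```nim
let v = Voucher(relayPeerId: relayPeerId,
                reservingPeerId: reservingPeerId,
                expiration: 1_700_000_000'u64)
let decoded = Voucher.decode(v.encode())
assert decoded.isOk() and decoded.get() == v
```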
@@ -135,23 +148,20 @@ proc encode*(v: Voucher): seq[byte] =
pb.finish()
pb.buffer
proc init*(
T: typedesc[Voucher],
relayPeerId: PeerId,
reservingPeerId: PeerId,
expiration: uint64,
): T =
proc init*(T: typedesc[Voucher],
relayPeerId: PeerId,
reservingPeerId: PeerId,
expiration: uint64): T =
T(
    relayPeerId: relayPeerId, reservingPeerId: reservingPeerId, expiration: expiration
    relayPeerId: relayPeerId,
    reservingPeerId: reservingPeerId,
    expiration: expiration
)
type SignedVoucher* = SignedPayload[Voucher]
proc payloadDomain*(_: typedesc[Voucher]): string =
"libp2p-relay-rsvp"
proc payloadType*(_: typedesc[Voucher]): seq[byte] =
@[(byte) 0x03, (byte) 0x02]
proc payloadDomain*(_: typedesc[Voucher]): string = "libp2p-relay-rsvp"
proc payloadType*(_: typedesc[Voucher]): seq[byte] = @[ (byte)0x03, (byte)0x02 ]
proc checkValid*(spr: SignedVoucher): Result[void, EnvelopeError] =
if not spr.data.relayPeerId.match(spr.envelope.publicKey):
@@ -165,15 +175,13 @@ type
Peer* = object
peerId*: PeerId
addrs*: seq[MultiAddress]
Reservation* = object
expire*: uint64 # required, Unix expiration time (UTC)
addrs*: seq[MultiAddress] # relay address for reserving peer
svoucher*: Opt[seq[byte]] # optional, reservation voucher
expire*: uint64 # required, Unix expiration time (UTC)
addrs*: seq[MultiAddress] # relay address for reserving peer
svoucher*: Option[seq[byte]] # optional, reservation voucher
Limit* = object
duration*: uint32 # seconds
data*: uint64 # bytes
duration*: uint32 # seconds
data*: uint64 # bytes
StatusV2* = enum
Ok = 100
@@ -184,92 +192,103 @@ type
NoReservation = 204
MalformedMessage = 400
UnexpectedMessage = 401
HopMessageType* {.pure.} = enum
Reserve = 0
Connect = 1
Status = 2
HopMessage* = object
msgType*: HopMessageType
peer*: Opt[Peer]
reservation*: Opt[Reservation]
peer*: Option[Peer]
reservation*: Option[Reservation]
limit*: Limit
status*: Opt[StatusV2]
status*: Option[StatusV2]
proc encode*(msg: HopMessage): ProtoBuffer =
var pb = initProtoBuffer()
pb.write(1, msg.msgType.ord.uint)
msg.peer.withValue(peer):
if msg.peer.isSome():
var ppb = initProtoBuffer()
ppb.write(1, peer.peerId)
for ma in peer.addrs:
ppb.write(1, msg.peer.get().peerId)
for ma in msg.peer.get().addrs:
ppb.write(2, ma.data.buffer)
ppb.finish()
pb.write(2, ppb.buffer)
msg.reservation.withValue(rsrv):
if msg.reservation.isSome():
let rsrv = msg.reservation.get()
var rpb = initProtoBuffer()
rpb.write(1, rsrv.expire)
for ma in rsrv.addrs:
rpb.write(2, ma.data.buffer)
rsrv.svoucher.withValue(vouch):
rpb.write(3, vouch)
if rsrv.svoucher.isSome():
rpb.write(3, rsrv.svoucher.get())
rpb.finish()
pb.write(3, rpb.buffer)
if msg.limit.duration > 0 or msg.limit.data > 0:
var lpb = initProtoBuffer()
if msg.limit.duration > 0:
lpb.write(1, msg.limit.duration)
if msg.limit.data > 0:
lpb.write(2, msg.limit.data)
if msg.limit.duration > 0: lpb.write(1, msg.limit.duration)
if msg.limit.data > 0: lpb.write(2, msg.limit.data)
lpb.finish()
pb.write(4, lpb.buffer)
msg.status.withValue(status):
pb.write(5, status.ord.uint)
if msg.status.isSome():
pb.write(5, msg.status.get().ord.uint)
pb.finish()
pb
proc decode*(_: typedesc[HopMessage], buf: seq[byte]): Opt[HopMessage] =
var msg: HopMessage
let pb = initProtoBuffer(buf)
proc decode*(_: typedesc[HopMessage], buf: seq[byte]): Option[HopMessage] =
var
msg: HopMessage
msgTypeOrd: uint32
pbPeer: ProtoBuffer
pbReservation: ProtoBuffer
pbLimit: ProtoBuffer
statusOrd: uint32
peer: Peer
reservation: Reservation
limit: Limit
res: bool
let
pb = initProtoBuffer(buf)
r1 = pb.getRequiredField(1, msgTypeOrd)
r2 = pb.getField(2, pbPeer)
r3 = pb.getField(3, pbReservation)
r4 = pb.getField(4, pbLimit)
r5 = pb.getField(5, statusOrd)
if r1.isErr() or r2.isErr() or r3.isErr() or r4.isErr() or r5.isErr():
return none(HopMessage)
if r2.get() and
(pbPeer.getRequiredField(1, peer.peerId).isErr() or
pbPeer.getRepeatedField(2, peer.addrs).isErr()):
return none(HopMessage)
if r3.get():
var svoucher: seq[byte]
let rSVoucher = pbReservation.getField(3, svoucher)
if pbReservation.getRequiredField(1, reservation.expire).isErr() or
pbReservation.getRepeatedField(2, reservation.addrs).isErr() or
rSVoucher.isErr():
return none(HopMessage)
if rSVoucher.get(): reservation.svoucher = some(svoucher)
if r4.get() and
(pbLimit.getField(1, limit.duration).isErr() or
pbLimit.getField(2, limit.data).isErr()):
return none(HopMessage)
var msgTypeOrd: uint32
?pb.getRequiredField(1, msgTypeOrd).toOpt()
if not checkedEnumAssign(msg.msgType, msgTypeOrd):
return Opt.none(HopMessage)
var pbPeer: ProtoBuffer
if ?pb.getField(2, pbPeer).toOpt():
var peer: Peer
?pbPeer.getRequiredField(1, peer.peerId).toOpt()
discard ?pbPeer.getRepeatedField(2, peer.addrs).toOpt()
msg.peer = Opt.some(peer)
var pbReservation: ProtoBuffer
if ?pb.getField(3, pbReservation).toOpt():
var
svoucher: seq[byte]
reservation: Reservation
if ?pbReservation.getField(3, svoucher).toOpt():
reservation.svoucher = Opt.some(svoucher)
?pbReservation.getRequiredField(1, reservation.expire).toOpt()
discard ?pbReservation.getRepeatedField(2, reservation.addrs).toOpt()
msg.reservation = Opt.some(reservation)
var pbLimit: ProtoBuffer
if ?pb.getField(4, pbLimit).toOpt():
discard ?pbLimit.getField(1, msg.limit.duration).toOpt()
discard ?pbLimit.getField(2, msg.limit.data).toOpt()
var statusOrd: uint32
if ?pb.getField(5, statusOrd).toOpt():
var status: StatusV2
if not checkedEnumAssign(status, statusOrd):
return Opt.none(HopMessage)
msg.status = Opt.some(status)
Opt.some(msg)
return none(HopMessage)
if r2.get(): msg.peer = some(peer)
if r3.get(): msg.reservation = some(reservation)
if r4.get(): msg.limit = limit
if r5.get():
if statusOrd.int notin StatusV2:
return none(HopMessage)
msg.status = some(StatusV2(statusOrd))
some(msg)
# Circuit Relay V2 Stop Message
@@ -277,65 +296,75 @@ type
StopMessageType* {.pure.} = enum
Connect = 0
Status = 1
StopMessage* = object
msgType*: StopMessageType
peer*: Opt[Peer]
peer*: Option[Peer]
limit*: Limit
status*: Opt[StatusV2]
status*: Option[StatusV2]
proc encode*(msg: StopMessage): ProtoBuffer =
var pb = initProtoBuffer()
pb.write(1, msg.msgType.ord.uint)
msg.peer.withValue(peer):
if msg.peer.isSome():
var ppb = initProtoBuffer()
ppb.write(1, peer.peerId)
for ma in peer.addrs:
ppb.write(1, msg.peer.get().peerId)
for ma in msg.peer.get().addrs:
ppb.write(2, ma.data.buffer)
ppb.finish()
pb.write(2, ppb.buffer)
if msg.limit.duration > 0 or msg.limit.data > 0:
var lpb = initProtoBuffer()
if msg.limit.duration > 0:
lpb.write(1, msg.limit.duration)
if msg.limit.data > 0:
lpb.write(2, msg.limit.data)
if msg.limit.duration > 0: lpb.write(1, msg.limit.duration)
if msg.limit.data > 0: lpb.write(2, msg.limit.data)
lpb.finish()
pb.write(3, lpb.buffer)
msg.status.withValue(status):
pb.write(4, status.ord.uint)
if msg.status.isSome():
pb.write(4, msg.status.get().ord.uint)
pb.finish()
pb
proc decode*(_: typedesc[StopMessage], buf: seq[byte]): Opt[StopMessage] =
var msg: StopMessage
proc decode*(_: typedesc[StopMessage], buf: seq[byte]): Option[StopMessage] =
var
msg: StopMessage
msgTypeOrd: uint32
pbPeer: ProtoBuffer
pbLimit: ProtoBuffer
statusOrd: uint32
peer: Peer
limit: Limit
rVoucher: ProtoResult[bool]
res: bool
let pb = initProtoBuffer(buf)
let
pb = initProtoBuffer(buf)
r1 = pb.getRequiredField(1, msgTypeOrd)
r2 = pb.getField(2, pbPeer)
r3 = pb.getField(3, pbLimit)
r4 = pb.getField(4, statusOrd)
var msgTypeOrd: uint32
?pb.getRequiredField(1, msgTypeOrd).toOpt()
if msgTypeOrd.int notin StopMessageType:
return Opt.none(StopMessage)
if r1.isErr() or r2.isErr() or r3.isErr() or r4.isErr():
return none(StopMessage)
if r2.get() and
(pbPeer.getRequiredField(1, peer.peerId).isErr() or
pbPeer.getRepeatedField(2, peer.addrs).isErr()):
return none(StopMessage)
if r3.get() and
(pbLimit.getField(1, limit.duration).isErr() or
pbLimit.getField(2, limit.data).isErr()):
return none(StopMessage)
if msgTypeOrd.int notin StopMessageType.low.ord .. StopMessageType.high.ord:
return none(StopMessage)
msg.msgType = StopMessageType(msgTypeOrd)
var pbPeer: ProtoBuffer
if ?pb.getField(2, pbPeer).toOpt():
var peer: Peer
?pbPeer.getRequiredField(1, peer.peerId).toOpt()
discard ?pbPeer.getRepeatedField(2, peer.addrs).toOpt()
msg.peer = Opt.some(peer)
var pbLimit: ProtoBuffer
if ?pb.getField(3, pbLimit).toOpt():
discard ?pbLimit.getField(1, msg.limit.duration).toOpt()
discard ?pbLimit.getField(2, msg.limit.data).toOpt()
var statusOrd: uint32
if ?pb.getField(4, statusOrd).toOpt():
var status: StatusV2
if not checkedEnumAssign(status, statusOrd):
return Opt.none(StopMessage)
msg.status = Opt.some(status)
Opt.some(msg)
if r2.get(): msg.peer = some(peer)
if r3.get(): msg.limit = limit
if r4.get():
if statusOrd.int notin StatusV2:
return none(StopMessage)
msg.status = some(StatusV2(statusOrd))
some(msg)
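And a symmetric round trip for the stop codec, using the Option-based variant from this hunk (somePeerId is a placeholder):

```nim
let stop = StopMessage(msgType: StopMessageType.Connect,
                       peer: some(Peer(peerId: somePeerId, addrs: @[])),
                       limit: Limit(duration: 120, data: 1'u64 shl 20))
let decoded = StopMessage.decode(encode(stop).buffer)
assert decoded.isSome() and decoded.get().msgType == StopMessageType.Connect
```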

@@ -1,5 +1,5 @@
# Nim-LibP2P
# Copyright (c) 2023-2024 Status Research & Development GmbH
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
@@ -7,27 +7,30 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import chronos
import ../../../stream/connection
type RelayConnection* = ref object of Connection
conn*: Connection
limitDuration*: uint32
limitData*: uint64
dataSent*: uint64
type
RelayConnection* = ref object of Connection
conn*: Connection
limitDuration*: uint32
limitData*: uint64
dataSent*: uint64
method readOnce*(
self: RelayConnection, pbytes: pointer, nbytes: int
): Future[int] {.async: (raises: [CancelledError, LPStreamError], raw: true).} =
self: RelayConnection,
pbytes: pointer,
nbytes: int): Future[int] {.async.} =
self.activity = true
self.conn.readOnce(pbytes, nbytes)
return await self.conn.readOnce(pbytes, nbytes)
method write*(
self: RelayConnection, msg: seq[byte]
): Future[void] {.async: (raises: [CancelledError, LPStreamError]).} =
method write*(self: RelayConnection, msg: seq[byte]): Future[void] {.async.} =
self.dataSent.inc(msg.len)
if self.limitData != 0 and self.dataSent > self.limitData:
await self.close()
@@ -35,28 +38,24 @@ method write*(
self.activity = true
await self.conn.write(msg)
method closeImpl*(self: RelayConnection): Future[void] {.async: (raises: []).} =
method closeImpl*(self: RelayConnection): Future[void] {.async.} =
await self.conn.close()
await procCall Connection(self).closeImpl()
method getWrapped*(self: RelayConnection): Connection =
self.conn
method getWrapped*(self: RelayConnection): Connection = self.conn
proc new*(
T: typedesc[RelayConnection],
conn: Connection,
limitDuration: uint32,
limitData: uint64,
): T =
T: typedesc[RelayConnection],
conn: Connection,
limitDuration: uint32,
limitData: uint64): T =
let rc = T(conn: conn, limitDuration: limitDuration, limitData: limitData)
rc.dir = conn.dir
rc.initStream()
if limitDuration > 0:
proc checkDurationConnection() {.async: (raises: []).} =
try:
await noCancel conn.join().wait(limitDuration.seconds())
except AsyncTimeoutError:
await conn.close()
proc checkDurationConnection() {.async.} =
let sleep = sleepAsync(limitDuration.seconds())
await sleep or conn.join()
if sleep.finished: await conn.close()
else: sleep.cancel()
asyncSpawn checkDurationConnection()
return rc
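In practice both legs of a circuit get wrapped in this type; a sketch with concrete limits (a zero limit means unlimited):

```nim
# Cap the relayed leg at 2 minutes and 128 KiB of written data.
let rconn = RelayConnection.new(conn, limitDuration = 120'u32,
                                limitData = 128'u64 * 1024)
```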

@@ -1,5 +1,5 @@
# Nim-LibP2P
# Copyright (c) 2023-2024 Status Research & Development GmbH
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
@@ -7,25 +7,28 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import sequtils, tables
import options, sequtils, tables, sugar
import chronos, chronicles
import
./messages,
./rconn,
./utils,
../../../peerinfo,
../../../switch,
../../../multiaddress,
../../../multicodec,
../../../stream/connection,
../../../protocols/protocol,
../../../errors,
../../../utils/heartbeat,
../../../signed_envelope
import ./messages,
./rconn,
./utils,
../../../peerinfo,
../../../switch,
../../../multiaddress,
../../../multicodec,
../../../stream/connection,
../../../protocols/protocol,
../../../transports/transport,
../../../errors,
../../../utils/heartbeat,
../../../signed_envelope
# TODO:
# * Eventually replace std/times by chronos/timer. Currently chronos/timer
@@ -55,56 +58,50 @@ type
# Relay Side
type Relay* = ref object of LPProtocol
switch*: Switch
peerCount: CountTable[PeerId]
type
Relay* = ref object of LPProtocol
switch*: Switch
peerCount: CountTable[PeerId]
# number of reservation (relayv2) + number of connection (relayv1)
maxCircuit*: int
# number of reservation (relayv2) + number of connection (relayv1)
maxCircuit*: int
maxCircuitPerPeer*: int
msgSize*: int
# RelayV1
isCircuitRelayV1*: bool
streamCount: int
# RelayV2
rsvp: Table[PeerId, DateTime]
reservationLoop: Future[void]
reservationTTL*: times.Duration
heartbeatSleepTime*: uint32
limit*: Limit
maxCircuitPerPeer*: int
msgSize*: int
# RelayV1
isCircuitRelayV1*: bool
streamCount: int
# RelayV2
rsvp: Table[PeerId, DateTime]
reservationLoop: Future[void]
reservationTTL*: times.Duration
heartbeatSleepTime*: uint32
limit*: Limit
# Relay V2
proc createReserveResponse(
r: Relay, pid: PeerId, expire: DateTime
): Result[HopMessage, CryptoError] =
r: Relay,
pid: PeerId,
expire: DateTime): Result[HopMessage, CryptoError] =
let
expireUnix = expire.toTime.toUnix.uint64
v = Voucher(
relayPeerId: r.switch.peerInfo.peerId,
reservingPeerId: pid,
expiration: expireUnix,
)
sv = ?SignedVoucher.init(r.switch.peerInfo.privateKey, v)
ma =
?MultiAddress.init("/p2p/" & $r.switch.peerInfo.peerId).orErr(
CryptoError.KeyError
)
rsrv = Reservation(
expire: expireUnix,
addrs: r.switch.peerInfo.addrs.mapIt(?it.concat(ma).orErr(CryptoError.KeyError)),
svoucher: Opt.some(?sv.encode),
)
msg = HopMessage(
msgType: HopMessageType.Status,
reservation: Opt.some(rsrv),
limit: r.limit,
status: Opt.some(Ok),
)
v = Voucher(relayPeerId: r.switch.peerInfo.peerId,
reservingPeerId: pid,
expiration: expireUnix)
sv = ? SignedVoucher.init(r.switch.peerInfo.privateKey, v)
ma = ? MultiAddress.init("/p2p/" & $r.switch.peerInfo.peerId).orErr(CryptoError.KeyError)
rsrv = Reservation(expire: expireUnix,
addrs: r.switch.peerInfo.addrs.mapIt(
? it.concat(ma).orErr(CryptoError.KeyError)),
svoucher: some(? sv.encode))
msg = HopMessage(msgType: HopMessageType.Status,
reservation: some(rsrv),
limit: r.limit,
status: some(Ok))
return ok(msg)
proc isRelayed*(conn: Connection): bool =
proc isRelayed(conn: Connection): bool =
var wrappedConn = conn
while not isNil(wrappedConn):
if wrappedConn of RelayConnection:
@@ -112,7 +109,7 @@ proc isRelayed*(conn: Connection): bool =
wrappedConn = wrappedConn.getWrapped()
return false
proc handleReserve(r: Relay, conn: Connection) {.async.} =
proc handleReserve(r: Relay, conn: Connection) {.async, gcsafe.} =
if conn.isRelayed():
trace "reservation attempt over relay connection", pid = conn.peerId
await sendHopStatus(conn, PermissionDenied)
@@ -122,28 +119,32 @@ proc handleReserve(r: Relay, conn: Connection) {.async.} =
trace "Too many reservations", pid = conn.peerId
await sendHopStatus(conn, ReservationRefused)
return
trace "reserving relay slot for", pid = conn.peerId
let
pid = conn.peerId
expire = now().utc + r.reservationTTL
msg = r.createReserveResponse(pid, expire).valueOr:
trace "error signing the voucher", pid
return
msg = r.createReserveResponse(pid, expire)
trace "reserving relay slot for", pid
if msg.isErr():
trace "error signing the voucher", error = error(msg), pid
return
r.rsvp[pid] = expire
await conn.writeLp(encode(msg).buffer)
await conn.writeLp(encode(msg.get()).buffer)
proc handleConnect(r: Relay, connSrc: Connection, msg: HopMessage) {.async.} =
proc handleConnect(r: Relay,
connSrc: Connection,
msg: HopMessage) {.async, gcsafe.} =
if connSrc.isRelayed():
trace "connection attempt over relay connection"
await sendHopStatus(connSrc, PermissionDenied)
return
if msg.peer.isNone():
await sendHopStatus(connSrc, MalformedMessage)
return
let
msgPeer = msg.peer.valueOr:
await sendHopStatus(connSrc, MalformedMessage)
return
src = connSrc.peerId
dst = msgPeer.peerId
dst = msg.peer.get().peerId
if dst notin r.rsvp:
trace "refusing connection, no reservation", src, dst
await sendHopStatus(connSrc, NoReservation)
@@ -155,42 +156,37 @@ proc handleConnect(r: Relay, connSrc: Connection, msg: HopMessage) {.async.} =
r.peerCount.inc(src, -1)
r.peerCount.inc(dst, -1)
if r.peerCount[src] > r.maxCircuitPerPeer or r.peerCount[dst] > r.maxCircuitPerPeer:
trace "too many connections",
src = r.peerCount[src], dst = r.peerCount[dst], max = r.maxCircuitPerPeer
if r.peerCount[src] > r.maxCircuitPerPeer or
r.peerCount[dst] > r.maxCircuitPerPeer:
trace "too many connections", src = r.peerCount[src],
dst = r.peerCount[dst],
max = r.maxCircuitPerPeer
await sendHopStatus(connSrc, ResourceLimitExceeded)
return
let connDst =
try:
await r.switch.dial(dst, RelayV2StopCodec)
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "error opening relay stream", dst, exc = exc.msg
await sendHopStatus(connSrc, ConnectionFailed)
return
let connDst = try:
await r.switch.dial(dst, RelayV2StopCodec)
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "error opening relay stream", dst, exc=exc.msg
await sendHopStatus(connSrc, ConnectionFailed)
return
defer:
await connDst.close()
proc sendStopMsg() {.async.} =
let stopMsg = StopMessage(
msgType: StopMessageType.Connect,
peer: Opt.some(Peer(peerId: src, addrs: @[])),
limit: r.limit,
)
let stopMsg = StopMessage(msgType: StopMessageType.Connect,
peer: some(Peer(peerId: src, addrs: @[])),
limit: r.limit)
await connDst.writeLp(encode(stopMsg).buffer)
let msg = StopMessage.decode(await connDst.readLp(r.msgSize)).valueOr:
raise newException(SendStopError, "Malformed message")
let msg = StopMessage.decode(await connDst.readLp(r.msgSize)).get()
if msg.msgType != StopMessageType.Status:
raise
newException(SendStopError, "Unexpected stop response, not a status message")
raise newException(SendStopError, "Unexpected stop response, not a status message")
if msg.status.get(UnexpectedMessage) != Ok:
raise newException(SendStopError, "Relay stop failure")
await connSrc.writeLp(
encode(HopMessage(msgType: HopMessageType.Status, status: Opt.some(Ok))).buffer
)
await connSrc.writeLp(encode(HopMessage(msgType: HopMessageType.Status,
status: some(Ok))).buffer)
try:
await sendStopMsg()
except CancelledError as exc:
@@ -209,54 +205,55 @@ proc handleConnect(r: Relay, connSrc: Connection, msg: HopMessage) {.async.} =
await rconnDst.close()
await bridge(rconnSrc, rconnDst)
proc handleHopStreamV2*(r: Relay, conn: Connection) {.async.} =
let msg = HopMessage.decode(await conn.readLp(r.msgSize)).valueOr:
proc handleHopStreamV2*(r: Relay, conn: Connection) {.async, gcsafe.} =
let msgOpt = HopMessage.decode(await conn.readLp(r.msgSize))
if msgOpt.isNone():
await sendHopStatus(conn, MalformedMessage)
return
trace "relayv2 handle stream", msg = msg
case msg.msgType
of HopMessageType.Reserve:
await r.handleReserve(conn)
of HopMessageType.Connect:
await r.handleConnect(conn, msg)
trace "relayv2 handle stream", msg = msgOpt.get()
let msg = msgOpt.get()
case msg.msgType:
of HopMessageType.Reserve: await r.handleReserve(conn)
of HopMessageType.Connect: await r.handleConnect(conn, msg)
else:
trace "Unexpected relayv2 handshake", msgType = msg.msgType
trace "Unexpected relayv2 handshake", msgType=msg.msgType
await sendHopStatus(conn, MalformedMessage)
# Relay V1
proc handleHop*(r: Relay, connSrc: Connection, msg: RelayMessage) {.async.} =
proc handleHop*(r: Relay, connSrc: Connection, msg: RelayMessage) {.async, gcsafe.} =
r.streamCount.inc()
defer:
r.streamCount.dec()
defer: r.streamCount.dec()
if r.streamCount + r.rsvp.len() >= r.maxCircuit:
trace "refusing connection; too many active circuit",
streamCount = r.streamCount, rsvp = r.rsvp.len()
trace "refusing connection; too many active circuit", streamCount = r.streamCount, rsvp = r.rsvp.len()
await sendStatus(connSrc, StatusV1.HopCantSpeakRelay)
return
var src, dst: RelayPeer
proc checkMsg(): Result[RelayMessage, StatusV1] =
src = msg.srcPeer.valueOr:
if msg.srcPeer.isNone:
return err(StatusV1.HopSrcMultiaddrInvalid)
let src = msg.srcPeer.get()
if src.peerId != connSrc.peerId:
return err(StatusV1.HopSrcMultiaddrInvalid)
dst = msg.dstPeer.valueOr:
if msg.dstPeer.isNone:
return err(StatusV1.HopDstMultiaddrInvalid)
let dst = msg.dstPeer.get()
if dst.peerId == r.switch.peerInfo.peerId:
return err(StatusV1.HopCantRelayToSelf)
if not r.switch.isConnected(dst.peerId):
trace "relay not connected to dst", dst
return err(StatusV1.HopNoConnToDst)
ok(msg)
let check = checkMsg()
if check.isErr:
await sendStatus(connSrc, check.error())
return
let
src = msg.srcPeer.get()
dst = msg.dstPeer.get()
if r.peerCount[src.peerId] >= r.maxCircuitPerPeer or
r.peerCount[dst.peerId] >= r.maxCircuitPerPeer:
r.peerCount[dst.peerId] >= r.maxCircuitPerPeer:
trace "refusing connection; too many connection from src or to dst", src, dst
await sendStatus(connSrc, StatusV1.HopCantSpeakRelay)
return
@@ -266,40 +263,40 @@ proc handleHop*(r: Relay, connSrc: Connection, msg: RelayMessage) {.async.} =
r.peerCount.inc(src.peerId, -1)
r.peerCount.inc(dst.peerId, -1)
let connDst =
try:
await r.switch.dial(dst.peerId, RelayV1Codec)
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "error opening relay stream", dst, exc = exc.msg
await sendStatus(connSrc, StatusV1.HopCantDialDst)
return
let connDst = try:
await r.switch.dial(dst.peerId, RelayV1Codec)
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "error opening relay stream", dst, exc=exc.msg
await sendStatus(connSrc, StatusV1.HopCantDialDst)
return
defer:
await connDst.close()
let msgToSend = RelayMessage(
msgType: Opt.some(RelayType.Stop), srcPeer: Opt.some(src), dstPeer: Opt.some(dst)
)
msgType: some(RelayType.Stop),
srcPeer: some(src),
dstPeer: some(dst))
let msgRcvFromDstOpt =
try:
await connDst.writeLp(encode(msgToSend).buffer)
RelayMessage.decode(await connDst.readLp(r.msgSize))
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "error writing stop handshake or reading stop response", exc = exc.msg
await sendStatus(connSrc, StatusV1.HopCantOpenDstStream)
return
let msgRcvFromDstOpt = try:
await connDst.writeLp(encode(msgToSend).buffer)
RelayMessage.decode(await connDst.readLp(r.msgSize))
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "error writing stop handshake or reading stop response", exc=exc.msg
await sendStatus(connSrc, StatusV1.HopCantOpenDstStream)
return
let msgRcvFromDst = msgRcvFromDstOpt.valueOr:
if msgRcvFromDstOpt.isNone:
trace "error reading stop response", msg = msgRcvFromDstOpt
await sendStatus(connSrc, StatusV1.HopCantOpenDstStream)
return
let msgRcvFromDst = msgRcvFromDstOpt.get()
if msgRcvFromDst.msgType.get(RelayType.Stop) != RelayType.Status or
msgRcvFromDst.status.get(StatusV1.StopRelayRefused) != StatusV1.Success:
msgRcvFromDst.status.get(StatusV1.StopRelayRefused) != StatusV1.Success:
trace "unexcepted relay stop response", msgRcvFromDst
await sendStatus(connSrc, StatusV1.HopCantOpenDstStream)
return
@@ -308,64 +305,51 @@ proc handleHop*(r: Relay, connSrc: Connection, msg: RelayMessage) {.async.} =
trace "relaying connection", src, dst
await bridge(connSrc, connDst)
proc handleStreamV1(r: Relay, conn: Connection) {.async.} =
let msg = RelayMessage.decode(await conn.readLp(r.msgSize)).valueOr:
proc handleStreamV1(r: Relay, conn: Connection) {.async, gcsafe.} =
let msgOpt = RelayMessage.decode(await conn.readLp(r.msgSize))
if msgOpt.isNone:
await sendStatus(conn, StatusV1.MalformedMessage)
return
trace "relay handle stream", msg
let typ = msg.msgType.valueOr:
trace "Message type not set"
await sendStatus(conn, StatusV1.MalformedMessage)
return
case typ
of RelayType.Hop:
await r.handleHop(conn, msg)
of RelayType.Stop:
await sendStatus(conn, StatusV1.StopRelayRefused)
of RelayType.CanHop:
await sendStatus(conn, StatusV1.Success)
else:
trace "Unexpected relay handshake", msgType = msg.msgType
await sendStatus(conn, StatusV1.MalformedMessage)
trace "relay handle stream", msg = msgOpt.get()
let msg = msgOpt.get()
case msg.msgType.get:
of RelayType.Hop: await r.handleHop(conn, msg)
of RelayType.Stop: await sendStatus(conn, StatusV1.StopRelayRefused)
of RelayType.CanHop: await sendStatus(conn, StatusV1.Success)
else:
trace "Unexpected relay handshake", msgType=msg.msgType
await sendStatus(conn, StatusV1.MalformedMessage)
proc setup*(r: Relay, switch: Switch) =
r.switch = switch
r.switch.addPeerEventHandler(
proc(peerId: PeerId, event: PeerEvent) {.async.} =
r.rsvp.del(peerId)
,
Left,
)
proc (peerId: PeerId, event: PeerEvent) {.async.} =
r.rsvp.del(peerId),
Left)
proc new*(
T: typedesc[Relay],
reservationTTL: times.Duration = DefaultReservationTTL,
limitDuration: uint32 = DefaultLimitDuration,
limitData: uint64 = DefaultLimitData,
heartbeatSleepTime: uint32 = DefaultHeartbeatSleepTime,
maxCircuit: int = MaxCircuit,
maxCircuitPerPeer: int = MaxCircuitPerPeer,
msgSize: int = RelayMsgSize,
circuitRelayV1: bool = false,
): T =
let r = T(
reservationTTL: reservationTTL,
limit: Limit(duration: limitDuration, data: limitData),
heartbeatSleepTime: heartbeatSleepTime,
maxCircuit: maxCircuit,
maxCircuitPerPeer: maxCircuitPerPeer,
msgSize: msgSize,
isCircuitRelayV1: circuitRelayV1,
)
proc new*(T: typedesc[Relay],
reservationTTL: times.Duration = DefaultReservationTTL,
limitDuration: uint32 = DefaultLimitDuration,
limitData: uint64 = DefaultLimitData,
heartbeatSleepTime: uint32 = DefaultHeartbeatSleepTime,
maxCircuit: int = MaxCircuit,
maxCircuitPerPeer: int = MaxCircuitPerPeer,
msgSize: int = RelayMsgSize,
circuitRelayV1: bool = false): T =
proc handleStream(conn: Connection, proto: string) {.async.} =
let r = T(reservationTTL: reservationTTL,
limit: Limit(duration: limitDuration, data: limitData),
heartbeatSleepTime: heartbeatSleepTime,
maxCircuit: maxCircuit,
maxCircuitPerPeer: maxCircuitPerPeer,
msgSize: msgSize,
isCircuitRelayV1: circuitRelayV1)
proc handleStream(conn: Connection, proto: string) {.async, gcsafe.} =
try:
case proto
of RelayV2HopCodec:
await r.handleHopStreamV2(conn)
of RelayV1Codec:
await r.handleStreamV1(conn)
case proto:
of RelayV2HopCodec: await r.handleHopStreamV2(conn)
of RelayV1Codec: await r.handleStreamV1(conn)
except CancelledError as exc:
raise exc
except CatchableError as exc:
@@ -374,11 +358,8 @@ proc new*(
trace "exiting relayv2 handler", conn
await conn.close()
r.codecs =
if r.isCircuitRelayV1:
@[RelayV1Codec]
else:
@[RelayV2HopCodec, RelayV1Codec]
r.codecs = if r.isCircuitRelayV1: @[RelayV1Codec]
else: @[RelayV2HopCodec, RelayV1Codec]
r.handler = handleStream
r
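# A minimal usage sketch for the constructor above (hedged: assumes an
# already-built Switch named `switch` and the usual Switch.mount API):
#
#   let relay = Relay.new(circuitRelayV1 = true)
#   relay.setup(switch)
#   switch.mount(relay)
#   await relay.start()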
@@ -389,23 +370,17 @@ proc deletesReservation(r: Relay) {.async.} =
if n > r.rsvp[k]:
r.rsvp.del(k)
method start*(r: Relay): Future[void] {.async: (raises: [CancelledError], raw: true).} =
let fut = newFuture[void]()
fut.complete()
method start*(r: Relay) {.async.} =
if not r.reservationLoop.isNil:
warn "Starting relay twice"
return fut
return
r.reservationLoop = r.deletesReservation()
r.started = true
fut
method stop*(r: Relay): Future[void] {.async: (raises: [], raw: true).} =
let fut = newFuture[void]()
fut.complete()
method stop*(r: Relay) {.async.} =
if r.reservationLoop.isNil:
warn "Stopping relay without starting it"
return fut
return
r.started = false
r.reservationLoop.cancel()
r.reservationLoop = nil
fut

View File

@@ -7,27 +7,30 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import sequtils, strutils
import chronos, chronicles
import
./client,
./rconn,
./utils,
../../../switch,
../../../stream/connection,
../../../transports/transport
import ./client,
./rconn,
./utils,
../../../switch,
../../../stream/connection,
../../../transports/transport
logScope:
topics = "libp2p relay relay-transport"
type RelayTransport* = ref object of Transport
client*: RelayClient
queue: AsyncQueue[Connection]
selfRunning: bool
type
RelayTransport* = ref object of Transport
client*: RelayClient
queue: AsyncQueue[Connection]
selfRunning: bool
method start*(self: RelayTransport, ma: seq[MultiAddress]) {.async.} =
if self.selfRunning:
@@ -35,44 +38,46 @@ method start*(self: RelayTransport, ma: seq[MultiAddress]) {.async.} =
return
self.client.onNewConnection = proc(
conn: Connection, duration: uint32 = 0, data: uint64 = 0
) {.async.} =
await self.queue.addLast(RelayConnection.new(conn, duration, data))
await conn.join()
conn: Connection,
duration: uint32 = 0,
data: uint64 = 0) {.async, gcsafe, raises: [Defect].} =
await self.queue.addLast(RelayConnection.new(conn, duration, data))
await conn.join()
self.selfRunning = true
await procCall Transport(self).start(ma)
trace "Starting Relay transport"
method stop*(self: RelayTransport) {.async.} =
method stop*(self: RelayTransport) {.async, gcsafe.} =
self.running = false
self.selfRunning = false
self.client.onNewConnection = nil
while not self.queue.empty():
await self.queue.popFirstNoWait().close()
method accept*(self: RelayTransport): Future[Connection] {.async.} =
method accept*(self: RelayTransport): Future[Connection] {.async, gcsafe.} =
result = await self.queue.popFirst()
proc dial*(self: RelayTransport, ma: MultiAddress): Future[Connection] {.async.} =
proc dial*(self: RelayTransport, ma: MultiAddress): Future[Connection] {.async, gcsafe.} =
let
sma = toSeq(ma.items())
relayAddrs = sma[0 .. sma.len - 4].mapIt(it.tryGet()).foldl(a & b)
relayAddrs = sma[0..sma.len-4].mapIt(it.tryGet()).foldl(a & b)
var
relayPeerId: PeerId
dstPeerId: PeerId
if not relayPeerId.init(($(sma[^3].tryGet())).split('/')[2]):
if not relayPeerId.init(($(sma[^3].get())).split('/')[2]):
raise newException(RelayV2DialError, "Relay doesn't exist")
if not dstPeerId.init(($(sma[^1].tryGet())).split('/')[2]):
if not dstPeerId.init(($(sma[^1].get())).split('/')[2]):
raise newException(RelayV2DialError, "Destination doesn't exist")
trace "Dial", relayPeerId, dstPeerId
let conn = await self.client.switch.dial(
relayPeerId, @[relayAddrs], @[RelayV2HopCodec, RelayV1Codec]
)
relayPeerId,
@[ relayAddrs ],
@[ RelayV2HopCodec, RelayV1Codec ])
conn.dir = Direction.Out
var rc: RelayConnection
try:
case conn.protocol
case conn.protocol:
of RelayV1Codec:
return await self.client.dialPeerV1(conn, dstPeerId, @[])
of RelayV2HopCodec:
@@ -81,27 +86,21 @@ proc dial*(self: RelayTransport, ma: MultiAddress): Future[Connection] {.async.}
except CancelledError as exc:
raise exc
except CatchableError as exc:
if not rc.isNil:
await rc.close()
if not rc.isNil: await rc.close()
raise exc
method dial*(
self: RelayTransport,
hostname: string,
ma: MultiAddress,
peerId: Opt[PeerId] = Opt.none(PeerId),
): Future[Connection] {.async.} =
peerId.withValue(pid):
let address = MultiAddress.init($ma & "/p2p/" & $pid).tryGet()
result = await self.dial(address)
self: RelayTransport,
hostname: string,
ma: MultiAddress,
peerId: Opt[PeerId] = Opt.none(PeerId)): Future[Connection] {.async, gcsafe.} =
let address = MultiAddress.init($ma & "/p2p/" & $peerId.get()).tryGet()
result = await self.dial(address)
method handles*(self: RelayTransport, ma: MultiAddress): bool {.gcsafe.} =
try:
if ma.protocols.isOk():
let sma = toSeq(ma.items())
result = sma.len >= 2 and CircuitRelay.match(sma[^1].tryGet())
except CatchableError as exc:
result = false
method handles*(self: RelayTransport, ma: MultiAddress): bool {.gcsafe} =
if ma.protocols.isOk():
let sma = toSeq(ma.items())
result = sma.len >= 2 and CircuitRelay.match(sma[^1].get())
trace "Handles return", ma, result
proc new*(T: typedesc[RelayTransport], cl: RelayClient, upgrader: Upgrade): T =

View File

@@ -1,5 +1,5 @@
# Nim-LibP2P
# Copyright (c) 2023-2024 Status Research & Development GmbH
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
@@ -7,10 +7,17 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import options
import chronos, chronicles
import ./messages, ../../../stream/connection
import ./messages,
../../../stream/connection
logScope:
topics = "libp2p relay relay-utils"
@@ -20,74 +27,63 @@ const
RelayV2HopCodec* = "/libp2p/circuit/relay/0.2.0/hop"
RelayV2StopCodec* = "/libp2p/circuit/relay/0.2.0/stop"
proc sendStatus*(
conn: Connection, code: StatusV1
) {.async: (raises: [CancelledError, LPStreamError], raw: true).} =
proc sendStatus*(conn: Connection, code: StatusV1) {.async, gcsafe.} =
trace "send relay/v1 status", status = $code & "(" & $ord(code) & ")"
let
msg = RelayMessage(msgType: Opt.some(RelayType.Status), status: Opt.some(code))
msg = RelayMessage(msgType: some(RelayType.Status), status: some(code))
pb = encode(msg)
conn.writeLp(pb.buffer)
await conn.writeLp(pb.buffer)
proc sendHopStatus*(
conn: Connection, code: StatusV2
) {.async: (raises: [CancelledError, LPStreamError], raw: true).} =
proc sendHopStatus*(conn: Connection, code: StatusV2) {.async, gcsafe.} =
trace "send hop relay/v2 status", status = $code & "(" & $ord(code) & ")"
let
msg = HopMessage(msgType: HopMessageType.Status, status: Opt.some(code))
msg = HopMessage(msgType: HopMessageType.Status, status: some(code))
pb = encode(msg)
conn.writeLp(pb.buffer)
await conn.writeLp(pb.buffer)
proc sendStopStatus*(
conn: Connection, code: StatusV2
) {.async: (raises: [CancelledError, LPStreamError], raw: true).} =
proc sendStopStatus*(conn: Connection, code: StatusV2) {.async.} =
trace "send stop relay/v2 status", status = $code & " (" & $ord(code) & ")"
let
msg = StopMessage(msgType: StopMessageType.Status, status: Opt.some(code))
msg = StopMessage(msgType: StopMessageType.Status, status: some(code))
pb = encode(msg)
conn.writeLp(pb.buffer)
await conn.writeLp(pb.buffer)
proc bridge*(
connSrc: Connection, connDst: Connection
) {.async: (raises: [CancelledError]).} =
proc bridge*(connSrc: Connection, connDst: Connection) {.async.} =
const bufferSize = 4096
var
bufSrcToDst: array[bufferSize, byte]
bufDstToSrc: array[bufferSize, byte]
futSrc = connSrc.readOnce(addr bufSrcToDst[0], bufSrcToDst.len)
futDst = connDst.readOnce(addr bufDstToSrc[0], bufDstToSrc.len)
bytesSentFromSrcToDst = 0
bytesSentFromDstToSrc = 0
futSrc = connSrc.readOnce(addr bufSrcToDst[0], bufSrcToDst.high + 1)
futDst = connDst.readOnce(addr bufDstToSrc[0], bufDstToSrc.high + 1)
bytesSendFromSrcToDst = 0
bytesSendFromDstToSrc = 0
bufRead: int
try:
while not connSrc.closed() and not connDst.closed():
try: # https://github.com/status-im/nim-chronos/issues/516
discard await race(futSrc, futDst)
except ValueError:
raiseAssert("Futures list is not empty")
await futSrc or futDst
if futSrc.finished():
bufRead = await futSrc
if bufRead > 0:
bytesSentFromSrcToDst.inc(bufRead)
await connDst.write(@bufSrcToDst[0 ..< bufRead])
zeroMem(addr bufSrcToDst[0], bufSrcToDst.len)
futSrc = connSrc.readOnce(addr bufSrcToDst[0], bufSrcToDst.len)
bytesSendFromSrcToDst.inc(bufRead)
await connDst.write(@bufSrcToDst[0..<bufRead])
zeroMem(addr(bufSrcToDst), bufSrcToDst.high + 1)
futSrc = connSrc.readOnce(addr bufSrcToDst[0], bufSrcToDst.high + 1)
if futDst.finished():
bufRead = await futDst
if bufRead > 0:
bytesSentFromDstToSrc += bufRead
await connSrc.write(bufDstToSrc[0 ..< bufRead])
zeroMem(addr bufDstToSrc[0], bufDstToSrc.len)
futDst = connDst.readOnce(addr bufDstToSrc[0], bufDstToSrc.len)
bytesSendFromDstToSrc += bufRead
await connSrc.write(bufDstToSrc[0..<bufRead])
zeroMem(addr(bufDstToSrc), bufDstToSrc.high + 1)
futDst = connDst.readOnce(addr bufDstToSrc[0], bufDstToSrc.high + 1)
except CancelledError as exc:
raise exc
except LPStreamError as exc:
except CatchableError as exc:
if connSrc.closed() or connSrc.atEof():
trace "relay src closed connection", src = connSrc.peerId
if connDst.closed() or connDst.atEof():
trace "relay dst closed connection", dst = connDst.peerId
trace "relay error", exc = exc.msg
trace "end relaying", bytesSentFromSrcToDst, bytesSentFromDstToSrc
trace "relay error", exc=exc.msg
trace "end relaying", bytesSendFromSrcToDst, bytesSendFromDstToSrc
await futSrc.cancelAndWait()
await futDst.cancelAndWait()
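# Note on the loop above: the new code replaces `await futSrc or futDst`
# with `race()` to work around nim-chronos issue 516; race() raises
# ValueError only when handed an empty futures list, hence the assert.
# A minimal sketch of the guarded wait:
#
#   try:
#     discard await race(futSrc, futDst) # returns once either future finishes
#   except ValueError:
#     raiseAssert("Futures list is not empty")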

View File

@@ -10,25 +10,23 @@
## `Identify <https://docs.libp2p.io/concepts/protocols/#identify>`_ and
## `Push Identify <https://docs.libp2p.io/concepts/protocols/#identify-push>`_ implementation
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import std/[sequtils, options, strutils, sugar]
import stew/results
import chronos, chronicles
import
../protobuf/minprotobuf,
../peerinfo,
../stream/connection,
../peerid,
../crypto/crypto,
../multiaddress,
../multicodec,
../protocols/protocol,
../utility,
../errors,
../observedaddrmanager
export observedaddrmanager
import ../protobuf/minprotobuf,
../peerinfo,
../stream/connection,
../peerid,
../crypto/crypto,
../multiaddress,
../protocols/protocol,
../utility,
../errors
logScope:
topics = "libp2p identify"
@@ -58,11 +56,12 @@ type
Identify* = ref object of LPProtocol
peerInfo*: PeerInfo
sendSignedPeerRecord*: bool
observedAddrManager*: ObservedAddrManager
IdentifyPushHandler* = proc(peer: PeerId, newInfo: IdentifyInfo): Future[void] {.
gcsafe, raises: [], public
.}
IdentifyPushHandler* = proc (
peer: PeerId,
newInfo: IdentifyInfo):
Future[void]
{.gcsafe, raises: [Defect], public.}
IdentifyPush* = ref object of LPProtocol
identifyHandler: IdentifyPushHandler
@@ -71,43 +70,48 @@ chronicles.expandIt(IdentifyInfo):
pubkey = ($it.pubkey).shortLog
addresses = it.addrs.map(x => $x).join(",")
protocols = it.protos.map(x => $x).join(",")
observable_address = $it.observedAddr
observable_address =
if it.observedAddr.isSome(): $it.observedAddr.get()
else: "None"
proto_version = it.protoVersion.get("None")
agent_version = it.agentVersion.get("None")
signedPeerRecord =
# The SPR contains the same data as the identify message and
# would be cumbersome to log
if it.signedPeerRecord.isSome(): "Some" else: "None"
if iinfo.signedPeerRecord.isSome(): "Some"
else: "None"
proc encodeMsg(
peerInfo: PeerInfo, observedAddr: Opt[MultiAddress], sendSpr: bool
): ProtoBuffer {.raises: [].} =
proc encodeMsg(peerInfo: PeerInfo, observedAddr: Opt[MultiAddress], sendSpr: bool): ProtoBuffer
{.raises: [Defect].} =
result = initProtoBuffer()
let pkey = peerInfo.publicKey
result.write(1, pkey.getBytes().expect("valid key"))
result.write(1, pkey.getBytes().get())
for ma in peerInfo.addrs:
result.write(2, ma.data.buffer)
for proto in peerInfo.protocols:
result.write(3, proto)
observedAddr.withValue(observed):
result.write(4, observed.data.buffer)
if observedAddr.isSome:
result.write(4, observedAddr.get().data.buffer)
let protoVersion = ProtoVersion
result.write(5, protoVersion)
let agentVersion =
if peerInfo.agentVersion.len <= 0: AgentVersion else: peerInfo.agentVersion
let agentVersion = if peerInfo.agentVersion.len <= 0:
AgentVersion
else:
peerInfo.agentVersion
result.write(6, agentVersion)
## Optionally populate signedPeerRecord field.
## See https://github.com/libp2p/go-libp2p/blob/ddf96ce1cfa9e19564feb9bd3e8269958bbc0aba/p2p/protocol/identify/pb/identify.proto for reference.
if sendSpr:
peerInfo.signedPeerRecord.envelope.encode().toOpt().withValue(sprBuff):
result.write(8, sprBuff)
let sprBuff = peerInfo.signedPeerRecord.envelope.encode()
if sprBuff.isOk():
result.write(8, sprBuff.get())
result.finish()
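# For reference, the identify protobuf fields written by encodeMsg above
# (and read back by decodeMsg below):
#   1: public key            2: listen addresses (repeated)
#   3: protocols (repeated)  4: observed address
#   5: protocol version      6: agent version
#   8: signed peer record (only when sendSpr is set)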
proc decodeMsg*(buf: seq[byte]): Opt[IdentifyInfo] =
proc decodeMsg*(buf: seq[byte]): Option[IdentifyInfo] =
var
iinfo: IdentifyInfo
pubkey: PublicKey
@@ -117,38 +121,52 @@ proc decodeMsg*(buf: seq[byte]): Opt[IdentifyInfo] =
signedPeerRecord: SignedPeerRecord
var pb = initProtoBuffer(buf)
if ?pb.getField(1, pubkey).toOpt():
iinfo.pubkey = some(pubkey)
if ?pb.getField(8, signedPeerRecord).toOpt() and
pubkey == signedPeerRecord.envelope.publicKey:
iinfo.signedPeerRecord = some(signedPeerRecord.envelope)
discard ?pb.getRepeatedField(2, iinfo.addrs).toOpt()
discard ?pb.getRepeatedField(3, iinfo.protos).toOpt()
if ?pb.getField(4, oaddr).toOpt():
iinfo.observedAddr = some(oaddr)
if ?pb.getField(5, protoVersion).toOpt():
iinfo.protoVersion = some(protoVersion)
if ?pb.getField(6, agentVersion).toOpt():
iinfo.agentVersion = some(agentVersion)
Opt.some(iinfo)
let r1 = pb.getField(1, pubkey)
let r2 = pb.getRepeatedField(2, iinfo.addrs)
let r3 = pb.getRepeatedField(3, iinfo.protos)
let r4 = pb.getField(4, oaddr)
let r5 = pb.getField(5, protoVersion)
let r6 = pb.getField(6, agentVersion)
let r8 = pb.getField(8, signedPeerRecord)
let res = r1.isOk() and r2.isOk() and r3.isOk() and
r4.isOk() and r5.isOk() and r6.isOk() and
r8.isOk()
if res:
if r1.get():
iinfo.pubkey = some(pubkey)
if r4.get():
iinfo.observedAddr = some(oaddr)
if r5.get():
iinfo.protoVersion = some(protoVersion)
if r6.get():
iinfo.agentVersion = some(agentVersion)
if r8.get() and r1.get():
if iinfo.pubkey.get() == signedPeerRecord.envelope.publicKey:
iinfo.signedPeerRecord = some(signedPeerRecord.envelope)
debug "decodeMsg: decoded identify", iinfo
some(iinfo)
else:
trace "decodeMsg: failed to decode received message"
none[IdentifyInfo]()
proc new*(
T: typedesc[Identify],
peerInfo: PeerInfo,
sendSignedPeerRecord = false,
observedAddrManager = ObservedAddrManager.new(),
): T =
T: typedesc[Identify],
peerInfo: PeerInfo,
sendSignedPeerRecord = false
): T =
let identify = T(
peerInfo: peerInfo,
sendSignedPeerRecord: sendSignedPeerRecord,
observedAddrManager: observedAddrManager,
sendSignedPeerRecord: sendSignedPeerRecord
)
identify.init()
identify
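# A minimal usage sketch (hedged: assumes a running Switch `switch` and a
# stream `conn` already negotiated for IdentifyCodec):
#
#   let idProto = Identify.new(switch.peerInfo, sendSignedPeerRecord = true)
#   switch.mount(idProto)
#   let info = await idProto.identify(conn, conn.peerId)
#   echo "remote speaks: ", info.protos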
method init*(p: Identify) =
proc handle(conn: Connection, proto: string) {.async.} =
proc handle(conn: Connection, proto: string) {.async, gcsafe, closure.} =
try:
trace "handling identify request", conn
var pb = encodeMsg(p.peerInfo, conn.observedAddr, p.sendSignedPeerRecord)
@@ -164,38 +182,34 @@ method init*(p: Identify) =
p.handler = handle
p.codec = IdentifyCodec
proc identify*(
self: Identify, conn: Connection, remotePeerId: PeerId
): Future[IdentifyInfo] {.async.} =
proc identify*(p: Identify,
conn: Connection,
remotePeerId: PeerId): Future[IdentifyInfo] {.async, gcsafe.} =
trace "initiating identify", conn
var message = await conn.readLp(64 * 1024)
var message = await conn.readLp(64*1024)
if len(message) == 0:
trace "identify: Empty message received!", conn
raise newException(IdentityInvalidMsgError, "Empty message received!")
var info = decodeMsg(message).valueOr:
let infoOpt = decodeMsg(message)
if infoOpt.isNone():
raise newException(IdentityInvalidMsgError, "Incorrect message received!")
debug "identify: decoded message", conn, info
let
pubkey = info.pubkey.valueOr:
raise newException(IdentityInvalidMsgError, "No pubkey in identify")
peer = PeerId.init(pubkey).valueOr:
raise newException(IdentityInvalidMsgError, $error)
result = infoOpt.get()
if peer != remotePeerId:
trace "Peer ids don't match", remote = peer, local = remotePeerId
raise newException(IdentityNoMatchError, "Peer ids don't match")
info.peerId = peer
if result.pubkey.isSome:
let peer = PeerId.init(result.pubkey.get())
if peer.isErr:
raise newException(IdentityInvalidMsgError, $peer.error)
else:
result.peerId = peer.get()
if peer.get() != remotePeerId:
trace "Peer ids don't match",
remote = peer,
local = remotePeerId
info.observedAddr.withValue(observed):
# Currently, we use the ObservedAddrManager only to find our dialable external NAT address. Therefore, addresses
# like "...\p2p-circuit\p2p\..." and "\p2p\..." are not useful to us.
if observed.contains(multiCodec("p2p-circuit")).get(false) or
P2PPattern.matchPartial(observed):
trace "Not adding address to ObservedAddrManager.", observed
elif not self.observedAddrManager.addObservation(observed):
trace "Observed address is not valid.", observedAddr = observed
return info
raise newException(IdentityNoMatchError, "Peer ids don't match")
else:
raise newException(IdentityInvalidMsgError, "No pubkey in identify")
proc new*(T: typedesc[IdentifyPush], handler: IdentifyPushHandler = nil): T {.public.} =
## Create a IdentifyPush protocol. `handler` will be called every time
@@ -205,24 +219,26 @@ proc new*(T: typedesc[IdentifyPush], handler: IdentifyPushHandler = nil): T {.pu
identifypush
proc init*(p: IdentifyPush) =
proc handle(conn: Connection, proto: string) {.async.} =
proc handle(conn: Connection, proto: string) {.async, gcsafe, closure.} =
trace "handling identify push", conn
try:
var message = await conn.readLp(64 * 1024)
var message = await conn.readLp(64*1024)
var identInfo = decodeMsg(message).valueOr:
let infoOpt = decodeMsg(message)
if infoOpt.isNone():
raise newException(IdentityInvalidMsgError, "Incorrect message received!")
debug "identify push: decoded message", conn, identInfo
identInfo.pubkey.withValue(pubkey):
let receivedPeerId = PeerId.init(pubkey).tryGet()
var indentInfo = infoOpt.get()
if indentInfo.pubkey.isSome:
let receivedPeerId = PeerId.init(indentInfo.pubkey.get()).tryGet()
if receivedPeerId != conn.peerId:
raise newException(IdentityNoMatchError, "Peer ids don't match")
identInfo.peerId = receivedPeerId
indentInfo.peerId = receivedPeerId
trace "triggering peer event", peerInfo = conn.peerId
if not isNil(p.identifyHandler):
await p.identifyHandler(conn.peerId, identInfo)
await p.identifyHandler(conn.peerId, indentInfo)
except CancelledError as exc:
raise exc
except CatchableError as exc:

View File

@@ -1,50 +0,0 @@
# Nim-LibP2P
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
## `Perf <https://github.com/libp2p/specs/blob/master/perf/perf.md>`_ protocol specification
import chronos, chronicles, sequtils
import stew/endians2
import ./core, ../../stream/connection
logScope:
topics = "libp2p perf"
type PerfClient* = ref object of RootObj
proc perf*(
_: typedesc[PerfClient],
conn: Connection,
sizeToWrite: uint64 = 0,
sizeToRead: uint64 = 0,
): Future[Duration] {.async, public.} =
var
size = sizeToWrite
buf: array[PerfSize, byte]
let start = Moment.now()
trace "starting performance benchmark", conn, sizeToWrite, sizeToRead
await conn.write(toSeq(toBytesBE(sizeToRead)))
while size > 0:
let toWrite = min(size, PerfSize)
await conn.write(buf[0 ..< toWrite])
size -= toWrite
await conn.close()
size = sizeToRead
while size > 0:
let toRead = min(size, PerfSize)
await conn.readExactly(addr buf[0], toRead.int)
size = size - toRead
let duration = Moment.now() - start
trace "finishing performance benchmark", duration
return duration
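# A minimal usage sketch (hedged: assumes `conn` is a stream negotiated for
# PerfCodec): upload 1 MiB, then read 1 MiB back, timing the whole exchange.
#
#   let elapsed = await PerfClient.perf(
#     conn, sizeToWrite = 1_048_576, sizeToRead = 1_048_576)
#   echo "perf round trip: ", elapsed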

View File

@@ -1,14 +0,0 @@
# Nim-LibP2P
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
## `Perf <https://github.com/libp2p/specs/blob/master/perf/perf.md>`_ protocol specification
const
PerfCodec* = "/perf/1.0.0"
PerfSize* = 65536

View File

@@ -1,57 +0,0 @@
# Nim-LibP2P
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
## `Perf <https://github.com/libp2p/specs/blob/master/perf/perf.md>`_ protocol specification
{.push raises: [].}
import chronos, chronicles
import stew/endians2
import ./core, ../protocol, ../../stream/connection, ../../utility
export chronicles, connection
logScope:
topics = "libp2p perf"
type Perf* = ref object of LPProtocol
proc new*(T: typedesc[Perf]): T {.public.} =
var p = T()
proc handle(conn: Connection, proto: string) {.async.} =
var bytesRead = 0
try:
trace "Received benchmark performance check", conn
var
sizeBuffer: array[8, byte]
size: uint64
await conn.readExactly(addr sizeBuffer[0], 8)
size = uint64.fromBytesBE(sizeBuffer)
var toReadBuffer: array[PerfSize, byte]
try:
while true:
bytesRead += await conn.readOnce(addr toReadBuffer[0], PerfSize)
except CatchableError as exc:
discard
var buf: array[PerfSize, byte]
while size > 0:
let toWrite = min(size, PerfSize)
await conn.write(buf[0 ..< toWrite])
size -= toWrite
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "exception in perf handler", exc = exc.msg, conn
await conn.close()
p.handler = handle
p.codec = PerfCodec
return p

View File

@@ -9,20 +9,22 @@
## `Ping <https://docs.libp2p.io/concepts/protocols/#ping>`_ protocol implementation
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import chronos, chronicles
import bearssl/rand
import
../protobuf/minprotobuf,
../peerinfo,
../stream/connection,
../peerid,
../crypto/crypto,
../multiaddress,
../protocols/protocol,
../utility,
../errors
import bearssl/[rand, hash]
import ../protobuf/minprotobuf,
../peerinfo,
../stream/connection,
../peerid,
../crypto/crypto,
../multiaddress,
../protocols/protocol,
../utility,
../errors
export chronicles, rand, connection
@@ -37,26 +39,27 @@ type
PingError* = object of LPError
WrongPingAckError* = object of PingError
PingHandler* {.public.} = proc(peer: PeerId): Future[void] {.gcsafe, raises: [].}
PingHandler* {.public.} = proc (
peer: PeerId):
Future[void]
{.gcsafe, raises: [Defect].}
Ping* = ref object of LPProtocol
pingHandler*: PingHandler
rng: ref HmacDrbgContext
proc new*(
T: typedesc[Ping], handler: PingHandler = nil, rng: ref HmacDrbgContext = newRng()
): T {.public.} =
proc new*(T: typedesc[Ping], handler: PingHandler = nil, rng: ref HmacDrbgContext = newRng()): T {.public.} =
let ping = Ping(pinghandler: handler, rng: rng)
ping.init()
ping
method init*(p: Ping) =
proc handle(conn: Connection, proto: string) {.async.} =
proc handle(conn: Connection, proto: string) {.async, gcsafe, closure.} =
try:
trace "handling ping", conn
var buf: array[PingSize, byte]
await conn.readExactly(addr buf[0], PingSize)
trace "echoing ping", conn, pingData = @buf
trace "echoing ping", conn
await conn.write(@buf)
if not isNil(p.pingHandler):
await p.pingHandler(conn.peerId)
@@ -68,7 +71,10 @@ method init*(p: Ping) =
p.handler = handle
p.codec = PingCodec
proc ping*(p: Ping, conn: Connection): Future[Duration] {.async, public.} =
proc ping*(
p: Ping,
conn: Connection,
): Future[Duration] {.async, gcsafe, public.} =
## Sends ping to `conn`, returns the delay
trace "initiating ping", conn
@@ -90,7 +96,7 @@ proc ping*(p: Ping, conn: Connection): Future[Duration] {.async, public.} =
trace "got ping response", conn, responseDur
for i in 0 ..< randomBuf.len:
for i in 0..<randomBuf.len:
if randomBuf[i] != resultBuf[i]:
raise newException(WrongPingAckError, "Incorrect ping data from peer!")
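# A minimal usage sketch (hedged: assumes a Switch `switch` and a dialed
# stream `conn` negotiated for PingCodec):
#
#   let pingProto = Ping.new()
#   switch.mount(pingProto)
#   let rtt = await pingProto.ping(conn)
#   echo "round trip: ", rtt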

View File

@@ -1,5 +1,5 @@
# Nim-LibP2P
# Copyright (c) 2023-2024 Status Research & Development GmbH
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
@@ -7,38 +7,35 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import chronos, stew/results
import ../stream/connection
export results
const DefaultMaxIncomingStreams* = 10
const
DefaultMaxIncomingStreams* = 10
type
LPProtoHandler* = proc(conn: Connection, proto: string): Future[void] {.async.}
LPProtoHandler* = proc (
conn: Connection,
proto: string):
Future[void]
{.gcsafe, raises: [Defect].}
LPProtocol* = ref object of RootObj
codecs*: seq[string]
handlerImpl: LPProtoHandler ## invoked by the protocol negotiator
handler*: LPProtoHandler ## this handler gets invoked by the protocol negotiator
started*: bool
maxIncomingStreams: Opt[int]
method init*(p: LPProtocol) {.base, gcsafe.} =
discard
method start*(p: LPProtocol) {.async: (raises: [CancelledError], raw: true), base.} =
let fut = newFuture[void]()
fut.complete()
p.started = true
fut
method stop*(p: LPProtocol) {.async: (raises: [], raw: true), base.} =
let fut = newFuture[void]()
fut.complete()
p.started = false
fut
method init*(p: LPProtocol) {.base, gcsafe.} = discard
method start*(p: LPProtocol) {.async, base.} = p.started = true
method stop*(p: LPProtocol) {.async, base.} = p.started = false
proc maxIncomingStreams*(p: LPProtocol): int =
p.maxIncomingStreams.get(DefaultMaxIncomingStreams)
@@ -47,7 +44,7 @@ proc `maxIncomingStreams=`*(p: LPProtocol, val: int) =
p.maxIncomingStreams = Opt.some(val)
func codec*(p: LPProtocol): string =
doAssert(p.codecs.len > 0, "Codecs sequence was empty!")
assert(p.codecs.len > 0, "Codecs sequence was empty!")
p.codecs[0]
func `codec=`*(p: LPProtocol, codec: string) =
@@ -55,54 +52,15 @@ func `codec=`*(p: LPProtocol, codec: string) =
# if we use this abstraction
p.codecs.insert(codec, 0)
template `handler`*(p: LPProtocol): LPProtoHandler =
p.handlerImpl
template `handler`*(p: LPProtocol, conn: Connection, proto: string): Future[void] =
p.handlerImpl(conn, proto)
func `handler=`*(p: LPProtocol, handler: LPProtoHandler) =
p.handlerImpl = handler
# Callbacks that are annotated with `{.async: (raises).}` explicitly
# document the types of errors that they may raise, but are not compatible
# with `LPProtoHandler` and need to use a custom `proc` type.
# They are internally wrapped into a `LPProtoHandler`, but still allow the
# compiler to check that their `{.async: (raises).}` annotation is correct.
# https://github.com/nim-lang/Nim/issues/23432
func `handler=`*[E](
p: LPProtocol,
handler: proc(conn: Connection, proto: string): InternalRaisesFuture[void, E],
) =
proc wrap(conn: Connection, proto: string): Future[void] {.async.} =
await handler(conn, proto)
p.handlerImpl = wrap
proc new*(
T: type LPProtocol,
codecs: seq[string],
handler: LPProtoHandler,
maxIncomingStreams: Opt[int] | int = Opt.none(int),
): T =
T: type LPProtocol,
codecs: seq[string],
handler: LPProtoHandler, # default(Opt[int]) or Opt.none(int) don't work on 1.2
maxIncomingStreams: Opt[int] | int = Opt[int]()): T =
T(
codecs: codecs,
handlerImpl: handler,
handler: handler,
maxIncomingStreams:
when maxIncomingStreams is int:
Opt.some(maxIncomingStreams)
else:
maxIncomingStreams
,
when maxIncomingStreams is int: Opt.some(maxIncomingStreams)
else: maxIncomingStreams
)
proc new*[E](
T: type LPProtocol,
codecs: seq[string],
handler: proc(conn: Connection, proto: string): InternalRaisesFuture[void, E],
maxIncomingStreams: Opt[int] | int = Opt.none(int),
): T =
proc wrap(conn: Connection, proto: string): Future[void] {.async.} =
await handler(conn, proto)
T.new(codecs, wrap, maxIncomingStreams)
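# A minimal usage sketch of the constructor above: a trivial echo protocol
# (hedged: the codec string is hypothetical):
#
#   proc echoHandler(conn: Connection, proto: string) {.async.} =
#     try:
#       let msg = await conn.readLp(1024) # read one length-prefixed frame
#       await conn.writeLp(msg)           # echo it back
#     finally:
#       await conn.close()
#
#   let echoProto = LPProtocol.new(@["/echo/1.0.0"], echoHandler)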

View File

@@ -3,7 +3,6 @@
import ../../utility
type ValidationResult* {.pure, public.} = enum
Accept
Reject
Ignore
type
ValidationResult* {.pure, public.} = enum
Accept, Reject, Ignore

View File

@@ -7,22 +7,23 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import std/[sets, hashes, tables]
import std/[sequtils, sets, hashes, tables]
import chronos, chronicles, metrics
import
./pubsub,
./pubsubpeer,
./timedcache,
./peertable,
./rpc/[message, messages, protobuf],
nimcrypto/[hash, sha2],
../../crypto/crypto,
../../stream/connection,
../../peerid,
../../peerinfo,
../../utility
import ./pubsub,
./pubsubpeer,
./timedcache,
./peertable,
./rpc/[message, messages],
../../crypto/crypto,
../../stream/connection,
../../peerid,
../../peerinfo,
../../utility
## Simple flood-based publishing.
@@ -31,33 +32,28 @@ logScope:
const FloodSubCodec* = "/floodsub/1.0.0"
type FloodSub* {.public.} = ref object of PubSub
floodsub*: PeerTable # topic to remote peer map
seen*: TimedCache[SaltedId]
# Early filter for messages recently observed on the network
# We use a salted id because the messages in this cache have not yet
# been validated meaning that an attacker has greater control over the
# hash key and therefore could poison the table
seenSalt*: sha256
# The salt in this case is a partially updated SHA256 context pre-seeded
# with some random data
type
FloodSub* {.public.} = ref object of PubSub
floodsub*: PeerTable # topic to remote peer map
seen*: TimedCache[MessageId] # message id:s already seen on the network
seenSalt*: seq[byte]
proc salt*(f: FloodSub, msgId: MessageId): SaltedId =
var tmp = f.seenSalt
tmp.update(msgId)
SaltedId(data: tmp.finish())
proc hasSeen*(f: FloodSub, msgId: MessageId): bool =
f.seenSalt & msgId in f.seen
proc hasSeen*(f: FloodSub, saltedId: SaltedId): bool =
saltedId in f.seen
proc addSeen*(f: FloodSub, saltedId: SaltedId): bool =
proc addSeen*(f: FloodSub, msgId: MessageId): bool =
# Salting the seen hash helps avoid attacks against the hash function used
# in the nim hash table
# Return true if the message has already been seen
f.seen.put(saltedId)
f.seen.put(f.seenSalt & msgId)
proc firstSeen*(f: FloodSub, saltedId: SaltedId): Moment =
f.seen.addedAt(saltedId)
proc firstSeen*(f: FloodSub, msgId: MessageId): Moment =
f.seen.addedAt(f.seenSalt & msgId)
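# Why a partially-updated hash context works as a salt: seenSalt is seeded
# once with random bytes, and salt() copies that context before appending
# the message id, so every salted id is effectively
# SHA256(randomSalt || msgId) without rehashing the salt each time.
# A minimal sketch of the idea (hedged: assumes nimcrypto's sha256 API;
# `randomBytes` is a hypothetical fresh random seed):
#
#   var seed: sha256
#   seed.init()
#   seed.update(randomBytes)          # done once at startup
#   proc saltedIdOf(msgId: seq[byte]): MDigest[256] =
#     var ctx = seed                  # cheap value copy of the seeded context
#     ctx.update(msgId)
#     ctx.finish()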
proc handleSubscribe(f: FloodSub, peer: PubSubPeer, topic: string, subscribe: bool) =
proc handleSubscribe*(f: FloodSub,
peer: PubSubPeer,
topic: string,
subscribe: bool) =
logScope:
peer
topic
@@ -70,8 +66,7 @@ proc handleSubscribe(f: FloodSub, peer: PubSubPeer, topic: string, subscribe: bo
trace "ignoring unknown peer"
return
if subscribe and not (isNil(f.subscriptionValidator)) and
not (f.subscriptionValidator(topic)):
if subscribe and not(isNil(f.subscriptionValidator)) and not(f.subscriptionValidator(topic)):
# this is a violation, so warn should be in order
warn "ignoring invalid topic subscription", topic, peer
return
@@ -101,22 +96,14 @@ method unsubscribePeer*(f: FloodSub, peer: PeerId) =
procCall PubSub(f).unsubscribePeer(peer)
method rpcHandler*(f: FloodSub, peer: PubSubPeer, data: seq[byte]) {.async.} =
var rpcMsg = decodeRpcMsg(data).valueOr:
debug "failed to decode msg from peer", peer, err = error
raise newException(CatchableError, "Peer msg couldn't be decoded")
trace "decoded msg from peer", peer, msg = rpcMsg.shortLog
# trigger hooks
peer.recvObservers(rpcMsg)
for i in 0 ..< min(f.topicsHigh, rpcMsg.subscriptions.len):
template sub(): untyped =
rpcMsg.subscriptions[i]
method rpcHandler*(f: FloodSub,
peer: PubSubPeer,
rpcMsg: RPCMsg) {.async.} =
for i in 0..<min(f.topicsHigh, rpcMsg.subscriptions.len):
template sub: untyped = rpcMsg.subscriptions[i]
f.handleSubscribe(peer, sub.topic, sub.subscribe)
for msg in rpcMsg.messages: # for every message
for msg in rpcMsg.messages: # for every message
let msgIdResult = f.msgIdProvider(msg)
if msgIdResult.isErr:
debug "Dropping message due to failed message id generation",
@@ -124,11 +111,9 @@ method rpcHandler*(f: FloodSub, peer: PubSubPeer, data: seq[byte]) {.async.} =
# TODO: descore peers due to error during message validation (malicious?)
continue
let
msgId = msgIdResult.get
saltedId = f.salt(msgId)
let msgId = msgIdResult.get
if f.addSeen(saltedId):
if f.addSeen(msgId):
trace "Dropping already-seen message", msgId, peer
continue
@@ -157,19 +142,16 @@ method rpcHandler*(f: FloodSub, peer: PubSubPeer, data: seq[byte]) {.async.} =
discard
var toSendPeers = initHashSet[PubSubPeer]()
let topic = msg.topic
if topic notin f.topics:
debug "Dropping message due to topic not in floodsub topics", topic, msgId, peer
continue
for t in msg.topicIds: # for every topic in the message
if t notin f.topics:
continue
f.floodsub.withValue(t, peers): toSendPeers.incl(peers[])
f.floodsub.withValue(topic, peers):
toSendPeers.incl(peers[])
await handleData(f, topic, msg.data)
await handleData(f, t, msg.data)
# In theory, if topics are the same in all messages, we could batch - we'd
# also have to be careful to only include validated messages
f.broadcast(toSendPeers, RPCMsg(messages: @[msg]), isHighPriority = false)
f.broadcast(toSendPeers, RPCMsg(messages: @[msg]))
trace "Forwared message to peers", peers = toSendPeers.len
f.updateMetrics(rpcMsg)
@@ -192,7 +174,9 @@ method init*(f: FloodSub) =
f.handler = handler
f.codec = FloodSubCodec
method publish*(f: FloodSub, topic: string, data: seq[byte]): Future[int] {.async.} =
method publish*(f: FloodSub,
topic: string,
data: seq[byte]): Future[int] {.async.} =
# base returns always 0
discard await procCall PubSub(f).publish(topic, data)
@@ -216,18 +200,20 @@ method publish*(f: FloodSub, topic: string, data: seq[byte]): Future[int] {.asyn
inc f.msgSeqno
Message.init(some(f.peerInfo), data, topic, some(f.msgSeqno), f.sign)
msgId = f.msgIdProvider(msg).valueOr:
trace "Error generating message id, skipping publish", error = error
trace "Error generating message id, skipping publish",
error = error
return 0
trace "Created new message", msg = shortLog(msg), peers = peers.len, topic, msgId
trace "Created new message",
msg = shortLog(msg), peers = peers.len, topic, msgId
if f.addSeen(f.salt(msgId)):
if f.addSeen(msgId):
# custom msgid providers might cause this
trace "Dropping already-seen message", msgId, topic
return 0
# Try to send to all peers that are known to be interested
f.broadcast(peers, RPCMsg(messages: @[msg]), isHighPriority = true)
f.broadcast(peers, RPCMsg(messages: @[msg]))
when defined(libp2p_expensive_metrics):
libp2p_pubsub_messages_published.inc(labelValues = [topic])
@@ -236,13 +222,11 @@ method publish*(f: FloodSub, topic: string, data: seq[byte]): Future[int] {.asyn
return peers.len
method initPubSub*(f: FloodSub) {.raises: [InitializationError].} =
method initPubSub*(f: FloodSub)
{.raises: [Defect, InitializationError].} =
procCall PubSub(f).initPubSub()
f.seen = TimedCache[SaltedId].init(2.minutes)
f.seenSalt.init()
var tmp: array[32, byte]
hmacDrbgGenerate(f.rng[], tmp)
f.seenSalt.update(tmp)
f.seen = TimedCache[MessageId].init(2.minutes)
f.seenSalt = newSeqUninitialized[byte](sizeof(Hash))
hmacDrbgGenerate(f.rng[], f.seenSalt)
f.init()

View File

@@ -1,5 +1,5 @@
# Nim-LibP2P
# Copyright (c) 2023-2024 Status Research & Development GmbH
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
@@ -9,25 +9,26 @@
## Gossip based publishing
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import std/[sets, sequtils]
import std/[tables, sets, options, sequtils]
import chronos, chronicles, metrics
import chronos/ratelimit
import
./pubsub,
./floodsub,
./pubsubpeer,
./peertable,
./mcache,
./timedcache,
./rpc/[messages, message, protobuf],
../protocol,
../../stream/connection,
../../peerinfo,
../../peerid,
../../utility,
../../switch
import ./pubsub,
./floodsub,
./pubsubpeer,
./peertable,
./mcache,
./timedcache,
./rpc/[messages, message],
../protocol,
../../stream/connection,
../../peerinfo,
../../peerid,
../../utility,
../../switch
import stew/results
export results
@@ -40,112 +41,49 @@ logScope:
topics = "libp2p gossipsub"
declareCounter(libp2p_gossipsub_failed_publish, "number of failed publish")
declareCounter(
libp2p_gossipsub_invalid_topic_subscription,
"number of invalid topic subscriptions that happened",
)
declareCounter(
libp2p_gossipsub_duplicate_during_validation,
"number of duplicates received during message validation",
)
declareCounter(
libp2p_gossipsub_idontwant_saved_messages, "number of duplicates avoided by idontwant"
)
declareCounter(
libp2p_gossipsub_saved_bytes,
"bytes saved by gossipsub optimizations",
labels = ["kind"],
)
declareCounter(libp2p_gossipsub_invalid_topic_subscription, "number of invalid topic subscriptions that happened")
declareCounter(libp2p_gossipsub_duplicate_during_validation, "number of duplicates received during message validation")
declareCounter(libp2p_gossipsub_duplicate, "number of duplicates received")
declareCounter(libp2p_gossipsub_received, "number of messages received (deduplicated)")
when defined(libp2p_expensive_metrics):
declareCounter(
libp2p_pubsub_received_messages,
"number of messages received",
labels = ["id", "topic"],
)
proc init*(
_: type[GossipSubParams],
pruneBackoff = 1.minutes,
unsubscribeBackoff = 5.seconds,
floodPublish = true,
gossipFactor: float64 = 0.25,
d = GossipSubD,
dLow = GossipSubDlo,
dHigh = GossipSubDhi,
dScore = GossipSubDlo,
dOut = GossipSubDlo - 1, # DLow - 1
dLazy = GossipSubD, # Like D,
heartbeatInterval = GossipSubHeartbeatInterval,
historyLength = GossipSubHistoryLength,
historyGossip = GossipSubHistoryGossip,
fanoutTTL = GossipSubFanoutTTL,
seenTTL = 2.minutes,
gossipThreshold = -100.0,
publishThreshold = -1000.0,
graylistThreshold = -10000.0,
opportunisticGraftThreshold = 0.0,
decayInterval = 1.seconds,
decayToZero = 0.01,
retainScore = 2.minutes,
appSpecificWeight = 0.0,
ipColocationFactorWeight = 0.0,
ipColocationFactorThreshold = 1.0,
behaviourPenaltyWeight = -1.0,
behaviourPenaltyDecay = 0.999,
directPeers = initTable[PeerId, seq[MultiAddress]](),
disconnectBadPeers = false,
enablePX = false,
bandwidthEstimatebps = 100_000_000, # 100 Mbps or 12.5 MBps
overheadRateLimit = Opt.none(tuple[bytes: int, interval: Duration]),
disconnectPeerAboveRateLimit = false,
maxNumElementsInNonPriorityQueue = DefaultMaxNumElementsInNonPriorityQueue,
): GossipSubParams =
proc init*(_: type[GossipSubParams]): GossipSubParams =
GossipSubParams(
explicit: true,
pruneBackoff: pruneBackoff,
unsubscribeBackoff: unsubscribeBackoff,
floodPublish: floodPublish,
gossipFactor: gossipFactor,
d: d,
dLow: dLow,
dHigh: dHigh,
dScore: dScore,
dOut: dOut,
dLazy: dLazy,
heartbeatInterval: heartbeatInterval,
historyLength: historyLength,
historyGossip: historyGossip,
fanoutTTL: fanoutTTL,
seenTTL: seenTTL,
gossipThreshold: gossipThreshold,
publishThreshold: publishThreshold,
graylistThreshold: graylistThreshold,
opportunisticGraftThreshold: opportunisticGraftThreshold,
decayInterval: decayInterval,
decayToZero: decayToZero,
retainScore: retainScore,
appSpecificWeight: appSpecificWeight,
ipColocationFactorWeight: ipColocationFactorWeight,
ipColocationFactorThreshold: ipColocationFactorThreshold,
behaviourPenaltyWeight: behaviourPenaltyWeight,
behaviourPenaltyDecay: behaviourPenaltyDecay,
directPeers: directPeers,
disconnectBadPeers: disconnectBadPeers,
enablePX: enablePX,
bandwidthEstimatebps: bandwidthEstimatebps,
overheadRateLimit: overheadRateLimit,
disconnectPeerAboveRateLimit: disconnectPeerAboveRateLimit,
maxNumElementsInNonPriorityQueue: maxNumElementsInNonPriorityQueue,
)
explicit: true,
pruneBackoff: 1.minutes,
unsubscribeBackoff: 5.seconds,
floodPublish: true,
gossipFactor: 0.25,
d: GossipSubD,
dLow: GossipSubDlo,
dHigh: GossipSubDhi,
dScore: GossipSubDlo,
dOut: GossipSubDlo - 1, # DLow - 1
dLazy: GossipSubD, # Like D
heartbeatInterval: GossipSubHeartbeatInterval,
historyLength: GossipSubHistoryLength,
historyGossip: GossipSubHistoryGossip,
fanoutTTL: GossipSubFanoutTTL,
seenTTL: 2.minutes,
gossipThreshold: -100,
publishThreshold: -1000,
graylistThreshold: -10000,
opportunisticGraftThreshold: 0,
decayInterval: 1.seconds,
decayToZero: 0.01,
retainScore: 2.minutes,
appSpecificWeight: 0.0,
ipColocationFactorWeight: 0.0,
ipColocationFactorThreshold: 1.0,
behaviourPenaltyWeight: -1.0,
behaviourPenaltyDecay: 0.999,
disconnectBadPeers: false,
enablePX: false
)
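# A minimal usage sketch of the keyword-argument form of init above (hedged:
# every other parameter keeps its default):
#
#   let params = GossipSubParams.init(
#     floodPublish = false,
#     heartbeatInterval = 700.milliseconds,
#   )
#   params.validateParameters().tryGet()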
proc validateParameters*(parameters: GossipSubParams): Result[void, cstring] =
if (parameters.dOut >= parameters.dLow) or (parameters.dOut > (parameters.d div 2)):
err(
"gossipsub: dOut parameter error, Number of outbound connections to keep in the mesh. Must be less than D_lo and at most D/2"
)
if (parameters.dOut >= parameters.dLow) or
(parameters.dOut > (parameters.d div 2)):
err("gossipsub: dOut parameter error, Number of outbound connections to keep in the mesh. Must be less than D_lo and at most D/2")
elif parameters.gossipThreshold >= 0:
err("gossipsub: gossipThreshold parameter error, Must be < 0")
elif parameters.unsubscribeBackoff.seconds <= 0:
@@ -170,8 +108,6 @@ proc validateParameters*(parameters: GossipSubParams): Result[void, cstring] =
err("gossipsub: behaviourPenaltyWeight parameter error, Must be negative")
elif parameters.behaviourPenaltyDecay < 0 or parameters.behaviourPenaltyDecay >= 1:
err("gossipsub: behaviourPenaltyDecay parameter error, Must be between 0 and 1")
elif parameters.maxNumElementsInNonPriorityQueue <= 0:
err("gossipsub: maxNumElementsInNonPriorityQueue parameter error, Must be > 0")
else:
ok()
@@ -181,29 +117,17 @@ proc validateParameters*(parameters: TopicParams): Result[void, cstring] =
elif parameters.timeInMeshCap <= 0.0:
err("gossipsub: timeInMeshCap parameter error, Should be a positive value")
elif parameters.firstMessageDeliveriesWeight <= 0.0:
err(
"gossipsub: firstMessageDeliveriesWeight parameter error, Should be a positive value"
)
err("gossipsub: firstMessageDeliveriesWeight parameter error, Should be a positive value")
elif parameters.meshMessageDeliveriesWeight >= 0.0:
err(
"gossipsub: meshMessageDeliveriesWeight parameter error, Should be a negative value"
)
err("gossipsub: meshMessageDeliveriesWeight parameter error, Should be a negative value")
elif parameters.meshMessageDeliveriesThreshold <= 0.0:
err(
"gossipsub: meshMessageDeliveriesThreshold parameter error, Should be a positive value"
)
err("gossipsub: meshMessageDeliveriesThreshold parameter error, Should be a positive value")
elif parameters.meshMessageDeliveriesCap < parameters.meshMessageDeliveriesThreshold:
err(
"gossipsub: meshMessageDeliveriesCap parameter error, Should be >= meshMessageDeliveriesThreshold"
)
err("gossipsub: meshMessageDeliveriesCap parameter error, Should be >= meshMessageDeliveriesThreshold")
elif parameters.meshFailurePenaltyWeight >= 0.0:
err(
"gossipsub: meshFailurePenaltyWeight parameter error, Should be a negative value"
)
err("gossipsub: meshFailurePenaltyWeight parameter error, Should be a negative value")
elif parameters.invalidMessageDeliveriesWeight >= 0.0:
err(
"gossipsub: invalidMessageDeliveriesWeight parameter error, Should be a negative value"
)
err("gossipsub: invalidMessageDeliveriesWeight parameter error, Should be a negative value")
else:
ok()
@@ -223,32 +147,26 @@ method init*(g: GossipSub) =
trace "GossipSub handler leaks an error", exc = exc.msg, conn
g.handler = handler
g.codecs &= GossipSubCodec_12
g.codecs &= GossipSubCodec_11
g.codecs &= GossipSubCodec
g.codecs &= GossipSubCodec_10
method onNewPeer*(g: GossipSub, peer: PubSubPeer) =
g.withPeerStats(peer.peerId) do(stats: var PeerStats):
method onNewPeer(g: GossipSub, peer: PubSubPeer) =
g.withPeerStats(peer.peerId) do (stats: var PeerStats):
# Make sure stats and peer information match, even when reloading peer stats
# from a previous connection
peer.score = stats.score
peer.appScore = stats.appScore
peer.behaviourPenalty = stats.behaviourPenalty
# Check if the score is below the threshold and disconnect the peer if necessary
g.disconnectIfBadScorePeer(peer, stats.score)
peer.iWantBudget = IWantPeerBudget
peer.iHaveBudget = IHavePeerBudget
peer.pingBudget = PingsPeerBudget
method onPubSubPeerEvent*(
p: GossipSub, peer: PubSubPeer, event: PubSubPeerEvent
) {.gcsafe.} =
method onPubSubPeerEvent*(p: GossipSub, peer: PubSubPeer, event: PubSubPeerEvent) {.gcsafe.} =
case event.kind
of PubSubPeerEventKind.StreamOpened:
of PubSubPeerEventKind.Connected:
discard
of PubSubPeerEventKind.StreamClosed:
# If a send stream is lost, it's better to remove peer from the mesh -
of PubSubPeerEventKind.Disconnected:
# If a send connection is lost, it's better to remove peer from the mesh -
# if it gets reestablished, the peer will be readded to the mesh, and if it
# doesn't, well.. then we hope the peer is going away!
for topic, peers in p.mesh.mpairs():
@@ -256,8 +174,6 @@ method onPubSubPeerEvent*(
peers.excl(peer)
for _, peers in p.fanout.mpairs():
peers.excl(peer)
of PubSubPeerEventKind.DisconnectionRequested:
asyncSpawn p.disconnectPeer(peer) # this should unsubscribePeer the peer too
procCall FloodSub(p).onPubSubPeerEvent(peer, event)
@@ -272,11 +188,11 @@ method unsubscribePeer*(g: GossipSub, peer: PeerId) =
return
# remove from peer IPs collection too
pubSubPeer.address.withValue(address):
g.peersInIP.withValue(address, s):
if pubSubPeer.address.isSome():
g.peersInIP.withValue(pubSubPeer.address.get(), s):
s[].excl(pubSubPeer.peerId)
if s[].len == 0:
g.peersInIP.del(address)
g.peersInIP.del(pubSubPeer.address.get())
for t in toSeq(g.mesh.keys):
trace "pruning unsubscribing peer", pubSubPeer, score = pubSubPeer.score
@@ -285,8 +201,8 @@ method unsubscribePeer*(g: GossipSub, peer: PeerId) =
for t in toSeq(g.gossipsub.keys):
g.gossipsub.removePeer(t, pubSubPeer)
# also try to remove from direct peers table here
g.subscribedDirectPeers.removePeer(t, pubSubPeer)
# also try to remove from explicit table here
g.explicit.removePeer(t, pubSubPeer)
for t in toSeq(g.fanout.keys):
g.fanout.removePeer(t, pubSubPeer)
@@ -295,11 +211,12 @@ method unsubscribePeer*(g: GossipSub, peer: PeerId) =
for topic, info in stats[].topicInfos.mpairs:
info.firstMessageDeliveries = 0
pubSubPeer.stopSendNonPriorityTask()
procCall FloodSub(g).unsubscribePeer(peer)
proc handleSubscribe(g: GossipSub, peer: PubSubPeer, topic: string, subscribe: bool) =
proc handleSubscribe*(g: GossipSub,
peer: PubSubPeer,
topic: string,
subscribe: bool) =
logScope:
peer
topic
@@ -313,7 +230,7 @@ proc handleSubscribe(g: GossipSub, peer: PubSubPeer, topic: string, subscribe: b
trace "ignoring unknown peer"
return
if not (isNil(g.subscriptionValidator)) and not (g.subscriptionValidator(topic)):
if not(isNil(g.subscriptionValidator)) and not(g.subscriptionValidator(topic)):
# this is a violation, so warn should be in order
trace "ignoring invalid topic subscription", topic, peer
libp2p_gossipsub_invalid_topic_subscription.inc()
@@ -324,7 +241,7 @@ proc handleSubscribe(g: GossipSub, peer: PubSubPeer, topic: string, subscribe: b
# subscribe remote peer to the topic
discard g.gossipsub.addPeer(topic, peer)
if peer.peerId in g.parameters.directPeers:
discard g.subscribedDirectPeers.addPeer(topic, peer)
discard g.explicit.addPeer(topic, peer)
else:
trace "peer unsubscribed from topic"
@@ -338,7 +255,7 @@ proc handleSubscribe(g: GossipSub, peer: PubSubPeer, topic: string, subscribe: b
g.fanout.removePeer(topic, peer)
if peer.peerId in g.parameters.directPeers:
g.subscribedDirectPeers.removePeer(topic, peer)
g.explicit.removePeer(topic, peer)
trace "gossip peers", peers = g.gossipsub.peers(topic), topic
@@ -346,97 +263,54 @@ proc handleControl(g: GossipSub, peer: PubSubPeer, control: ControlMessage) =
g.handlePrune(peer, control.prune)
var respControl: ControlMessage
g.handleIDontWant(peer, control.idontwant)
let iwant = g.handleIHave(peer, control.ihave)
if iwant.messageIDs.len > 0:
if iwant.messageIds.len > 0:
respControl.iwant.add(iwant)
respControl.prune.add(g.handleGraft(peer, control.graft))
let messages = g.handleIWant(peer, control.iwant)
let
isPruneNotEmpty = respControl.prune.len > 0
isIWantNotEmpty = respControl.iwant.len > 0
if respControl.prune.len > 0 or respControl.iwant.len > 0 or messages.len > 0:
# iwant and prunes from here, also messages
if isPruneNotEmpty or isIWantNotEmpty:
if isIWantNotEmpty:
libp2p_pubsub_broadcast_iwant.inc(respControl.iwant.len.int64)
if isPruneNotEmpty:
for prune in respControl.prune:
if g.knownTopics.contains(prune.topicID):
libp2p_pubsub_broadcast_prune.inc(labelValues = [prune.topicID])
for smsg in messages:
for topic in smsg.topicIds:
if g.knownTopics.contains(topic):
libp2p_pubsub_broadcast_messages.inc(labelValues = [topic])
else:
libp2p_pubsub_broadcast_prune.inc(labelValues = ["generic"])
libp2p_pubsub_broadcast_messages.inc(labelValues = ["generic"])
libp2p_pubsub_broadcast_iwant.inc(respControl.iwant.len.int64)
for prune in respControl.prune:
if g.knownTopics.contains(prune.topicId):
libp2p_pubsub_broadcast_prune.inc(labelValues = [prune.topicId])
else:
libp2p_pubsub_broadcast_prune.inc(labelValues = ["generic"])
trace "sending control message", msg = shortLog(respControl), peer
g.send(peer, RPCMsg(control: some(respControl)), isHighPriority = true)
g.send(
peer,
RPCMsg(control: some(respControl), messages: messages))
if messages.len > 0:
for smsg in messages:
let topic = smsg.topic
if g.knownTopics.contains(topic):
libp2p_pubsub_broadcast_messages.inc(labelValues = [topic])
else:
libp2p_pubsub_broadcast_messages.inc(labelValues = ["generic"])
# iwant replies have lower priority
trace "sending iwant reply messages", peer
g.send(peer, RPCMsg(messages: messages), isHighPriority = false)
proc validateAndRelay(
g: GossipSub, msg: Message, msgId: MessageId, saltedId: SaltedId, peer: PubSubPeer
) {.async.} =
proc validateAndRelay(g: GossipSub,
msg: Message,
msgId, msgIdSalted: MessageId,
peer: PubSubPeer) {.async.} =
try:
template topic(): string =
msg.topic
proc addToSendPeers(toSendPeers: var HashSet[PubSubPeer]) =
g.floodsub.withValue(topic, peers):
toSendPeers.incl(peers[])
g.mesh.withValue(topic, peers):
toSendPeers.incl(peers[])
g.subscribedDirectPeers.withValue(topic, peers):
toSendPeers.incl(peers[])
toSendPeers.excl(peer)
if msg.data.len > max(512, msgId.len * 10):
# If the message is "large enough", let the mesh know that we do not want
# any more copies of it, regardless of whether it is valid.
#
# In the case that it is not valid, this leads to some redundancy
# (since the other peer should not send us an invalid message regardless),
# but the expectation is that this is rare (due to such peers getting
# descored) and that the savings from honest peers are greater than the
# cost a dishonest peer can incur in short time (since the IDONTWANT is
# small).
var peersToSendIDontWant = HashSet[PubSubPeer]()
addToSendPeers(peersToSendIDontWant)
peersToSendIDontWant.exclIfIt(
it.codec == GossipSubCodec_10 or it.codec == GossipSubCodec_11
)
g.broadcast(
peersToSendIDontWant,
RPCMsg(
control:
some(ControlMessage(idontwant: @[ControlIWant(messageIDs: @[msgId])]))
),
isHighPriority = true,
)
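# Worked example of the size gate above: with a typical 20-byte message id
# the threshold is max(512, 20 * 10) = 512 bytes, so only payloads larger
# than 512 bytes trigger an IDONTWANT, while the IDONTWANT itself stays a
# few dozen bytes (the id plus protobuf framing); that asymmetry is what
# keeps the optimization cheap for honest peers.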
let validation = await g.validate(msg)
var seenPeers: HashSet[PubSubPeer]
discard g.validationSeen.pop(saltedId, seenPeers)
discard g.validationSeen.pop(msgIdSalted, seenPeers)
libp2p_gossipsub_duplicate_during_validation.inc(seenPeers.len.int64)
libp2p_gossipsub_saved_bytes.inc(
(msg.data.len * seenPeers.len).int64, labelValues = ["validation_duplicate"]
)
case validation
of ValidationResult.Reject:
debug "Dropping message after validation, reason: reject",
msgId = shortLog(msgId), peer
await g.punishInvalidMessage(peer, msg)
g.punishInvalidMessage(peer, msg.topicIds)
return
of ValidationResult.Ignore:
debug "Dropping message after validation, reason: ignore",
@@ -445,141 +319,72 @@ proc validateAndRelay(
of ValidationResult.Accept:
discard
if topic notin g.topics:
return # Topic was unsubscribed while validating
# store in cache only after validation
g.mcache.put(msgId, msg)
g.rewardDelivered(peer, topic, true)
g.rewardDelivered(peer, msg.topicIds, true)
# The send list typically matches the idontwant list from above, but
# might differ if validation takes time
var toSendPeers = HashSet[PubSubPeer]()
addToSendPeers(toSendPeers)
# Don't send it to peers that sent it during validation
for t in msg.topicIds: # for every topic in the message
if t notin g.topics:
continue
g.floodsub.withValue(t, peers): toSendPeers.incl(peers[])
g.mesh.withValue(t, peers): toSendPeers.incl(peers[])
# Don't send it to source peer, or peers that
# sent it during validation
toSendPeers.excl(peer)
toSendPeers.excl(seenPeers)
proc isMsgInIdontWant(it: PubSubPeer): bool =
for iDontWant in it.iDontWants:
if saltedId in iDontWant:
libp2p_gossipsub_idontwant_saved_messages.inc
libp2p_gossipsub_saved_bytes.inc(
msg.data.len.int64, labelValues = ["idontwant"]
)
return true
return false
toSendPeers.exclIfIt(isMsgInIdontWant(it))
# In theory, if topics are the same in all messages, we could batch - we'd
# also have to be careful to only include validated messages
g.broadcast(toSendPeers, RPCMsg(messages: @[msg]), isHighPriority = false)
trace "forwarded message to peers", peers = toSendPeers.len, msgId, peer
g.broadcast(toSendPeers, RPCMsg(messages: @[msg]))
trace "forwared message to peers", peers = toSendPeers.len, msgId, peer
for topic in msg.topicIds:
if topic notin g.topics: continue
if g.knownTopics.contains(topic):
libp2p_pubsub_messages_rebroadcasted.inc(
toSendPeers.len.int64, labelValues = [topic]
)
else:
libp2p_pubsub_messages_rebroadcasted.inc(
toSendPeers.len.int64, labelValues = ["generic"]
)
await handleData(g, topic, msg.data)
except CatchableError as exc:
info "validateAndRelay failed", msg = exc.msg
proc dataAndTopicsIdSize(msgs: seq[Message]): int =
msgs.mapIt(it.data.len + it.topic.len).foldl(a + b, 0)
proc messageOverhead(g: GossipSub, msg: RPCMsg, msgSize: int): int =
# In this way we count even ignored fields by protobuf
let
payloadSize =
if g.verifySignature:
byteSize(msg.messages)
if g.knownTopics.contains(topic):
libp2p_pubsub_messages_rebroadcasted.inc(toSendPeers.len.int64, labelValues = [topic])
else:
dataAndTopicsIdSize(msg.messages)
controlSize = msg.control.withValue(control):
byteSize(control.ihave) + byteSize(control.iwant)
do:
0
libp2p_pubsub_messages_rebroadcasted.inc(toSendPeers.len.int64, labelValues = ["generic"])
msgSize - payloadSize - controlSize
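# Worked example of the overhead computation above, with hypothetical
# sizes: whatever part of the wire size is neither message payload nor
# IHAVE/IWANT control data counts against the peer's overhead budget.
let
  msgSize = 1_200 # bytes received on the wire
  payloadSize = 1_000 # byteSize(msg.messages), or data + topic lengths
  controlSize = 120 # byteSize(control.ihave) + byteSize(control.iwant)
echo msgSize - payloadSize - controlSize # 80 bytes charged as overhead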
proc rateLimit*(g: GossipSub, peer: PubSubPeer, overhead: int) {.async.} =
peer.overheadRateLimitOpt.withValue(overheadRateLimit):
if not overheadRateLimit.tryConsume(overhead):
libp2p_gossipsub_peers_rate_limit_hits.inc(labelValues = [peer.getAgent()])
# let's just measure at the beginning for test purposes.
debug "Peer sent too much useless application data and it's above rate limit.",
peer, overhead
if g.parameters.disconnectPeerAboveRateLimit:
await g.disconnectPeer(peer)
raise newException(
PeerRateLimitError, "Peer disconnected because it's above rate limit."
)
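# Minimal sketch of the token bucket used by rateLimit above; the
# 4096-bytes-per-second budget is a hypothetical choice, not a default
# from this codebase.
import chronos
import chronos/ratelimit

let bucket = TokenBucket.new(4096, 1.seconds) # refill 4096 tokens per second
if not bucket.tryConsume(512): # spend 512 tokens of overhead
  echo "over budget" # rateLimit would count a rate-limit hit here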
method rpcHandler*(g: GossipSub, peer: PubSubPeer, data: seq[byte]) {.async.} =
let msgSize = data.len
var rpcMsg = decodeRpcMsg(data).valueOr:
debug "failed to decode msg from peer", peer, err = error
await rateLimit(g, peer, msgSize)
# Raising in the handler closes the gossipsub connection (but doesn't
# disconnect the peer!)
# TODO evaluate behaviour penalty values
peer.behaviourPenalty += 0.1
raise newException(CatchableError, "Peer msg couldn't be decoded")
when defined(libp2p_expensive_metrics):
for m in rpcMsg.messages:
libp2p_pubsub_received_messages.inc(labelValues = [$peer.peerId, m.topic])
trace "decoded msg from peer", peer, msg = rpcMsg.shortLog
await rateLimit(g, peer, g.messageOverhead(rpcMsg, msgSize))
# trigger hooks - these may modify the message
peer.recvObservers(rpcMsg)
if rpcMsg.ping.len in 1 ..< 64 and peer.pingBudget > 0:
g.send(peer, RPCMsg(pong: rpcMsg.ping), isHighPriority = true)
peer.pingBudget.dec
for i in 0 ..< min(g.topicsHigh, rpcMsg.subscriptions.len):
template sub(): untyped =
rpcMsg.subscriptions[i]
await handleData(g, topic, msg.data)
except CatchableError as exc:
info "validateAndRelay failed", msg=exc.msg
method rpcHandler*(g: GossipSub,
peer: PubSubPeer,
rpcMsg: RPCMsg) {.async.} =
for i in 0..<min(g.topicsHigh, rpcMsg.subscriptions.len):
template sub: untyped = rpcMsg.subscriptions[i]
g.handleSubscribe(peer, sub.topic, sub.subscribe)
# the above call applied limits to subs number
# in gossipsub we want to apply scoring as well
if rpcMsg.subscriptions.len > g.topicsHigh:
debug "received an rpc message with an oversized amount of subscriptions",
peer, size = rpcMsg.subscriptions.len, limit = g.topicsHigh
debug "received an rpc message with an oversized amount of subscriptions", peer,
size = rpcMsg.subscriptions.len,
limit = g.topicsHigh
peer.behaviourPenalty += 0.1
for i in 0 ..< rpcMsg.messages.len(): # for every message
template msg(): untyped =
rpcMsg.messages[i]
template topic(): string =
msg.topic
for i in 0..<rpcMsg.messages.len(): # for every message
template msg: untyped = rpcMsg.messages[i]
let msgIdResult = g.msgIdProvider(msg)
if msgIdResult.isErr:
debug "Dropping message due to failed message id generation",
error = msgIdResult.error
await g.punishInvalidMessage(peer, msg)
# TODO: descore peers due to error during message validation (malicious?)
continue
let
msgId = msgIdResult.get
msgIdSalted = g.salt(msgId)
msgIdSalted = msgId & g.seenSalt
if g.addSeen(msgIdSalted):
# addSeen adds salt to msgId to stop a remote
# peer from attacking the hash function
if g.addSeen(msgId):
trace "Dropping already-seen message", msgId = shortLog(msgId), peer
var alreadyReceived = false
@@ -589,8 +394,8 @@ method rpcHandler*(g: GossipSub, peer: PubSubPeer, data: seq[byte]) {.async.} =
alreadyReceived = true
if not alreadyReceived:
let delay = Moment.now() - g.firstSeen(msgIdSalted)
g.rewardDelivered(peer, topic, false, delay)
let delay = Moment.now() - g.firstSeen(msgId)
g.rewardDelivered(peer, msg.topicIds, false, delay)
libp2p_gossipsub_duplicate.inc()
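# Sketch of the seen-cache salting mentioned above (not code from this
# diff): mixing a random node-local salt into the message id keeps a
# remote peer from predicting cache keys and engineering collisions.
import std/sysrand

let seenSalt = urandom(8) # random per-node salt, 8 bytes
proc salted(msgId: seq[byte]): seq[byte] =
  msgId & seenSalt # the pre-refactor "msgId & g.seenSalt" form above

echo salted(@[1'u8, 2, 3]).len # 11: the id plus the salt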
@@ -600,23 +405,22 @@ method rpcHandler*(g: GossipSub, peer: PubSubPeer, data: seq[byte]) {.async.} =
libp2p_gossipsub_received.inc()
# avoid processing messages we are not interested in
if topic notin g.topics:
debug "Dropping message of topic without subscription",
msgId = shortLog(msgId), peer
if msg.topicIds.allIt(it notin g.topics):
debug "Dropping message of topic without subscription", msgId = shortLog(msgId), peer
continue
if (msg.signature.len > 0 or g.verifySignature) and not msg.verify():
# always validate if signature is present or required
debug "Dropping message due to failed signature verification",
msgId = shortLog(msgId), peer
await g.punishInvalidMessage(peer, msg)
g.punishInvalidMessage(peer, msg.topicIds)
continue
if msg.seqno.len > 0 and msg.seqno.len != 8:
# if present, the seqno should be 8 bytes long
debug "Dropping message due to invalid seqno length",
msgId = shortLog(msgId), peer
await g.punishInvalidMessage(peer, msg)
g.punishInvalidMessage(peer, msg.topicIds)
continue
# g.anonymize needs no evaluation when receiving messages
@@ -632,7 +436,7 @@ method rpcHandler*(g: GossipSub, peer: PubSubPeer, data: seq[byte]) {.async.} =
g.handleControl(peer, rpcMsg.control.unsafeGet())
# Now, check subscription to update the meshes if required
for i in 0 ..< min(g.topicsHigh, rpcMsg.subscriptions.len):
for i in 0..<min(g.topicsHigh, rpcMsg.subscriptions.len):
let topic = rpcMsg.subscriptions[i].topic
if topic in g.topics and g.mesh.peers(topic) < g.parameters.dLow:
# rebalance but don't update metrics here, we do that only in the heartbeat
@@ -655,101 +459,83 @@ method onTopicSubscription*(g: GossipSub, topic: string, subscribed: bool) =
# Remove peers from the mesh since we're no longer both interested
# in the topic
let msg = RPCMsg(
control: some(
ControlMessage(
prune:
@[
ControlPrune(
topicID: topic,
peers: g.peerExchangeList(topic),
backoff: g.parameters.unsubscribeBackoff.seconds.uint64,
)
]
)
)
)
g.broadcast(mpeers, msg, isHighPriority = true)
let msg = RPCMsg(control: some(ControlMessage(
prune: @[ControlPrune(
topicID: topic,
peers: g.peerExchangeList(topic),
backoff: g.parameters.unsubscribeBackoff.seconds.uint64)])))
g.broadcast(mpeers, msg)
for peer in mpeers:
g.pruned(peer, topic, backoff = some(g.parameters.unsubscribeBackoff))
g.mesh.del(topic)
# Send unsubscribe (in reverse order to sub/graft)
procCall PubSub(g).onTopicSubscription(topic, subscribed)
method publish*(g: GossipSub, topic: string, data: seq[byte]): Future[int] {.async.} =
method publish*(g: GossipSub,
topic: string,
data: seq[byte]): Future[int] {.async.} =
# base returns always 0
discard await procCall PubSub(g).publish(topic, data)
logScope:
topic
trace "Publishing message on topic", data = data.shortLog
if topic.len <= 0: # data could be 0/empty
debug "Empty topic, skipping publish"
return 0
# base returns always 0
discard await procCall PubSub(g).publish(topic, data)
trace "Publishing message on topic", data = data.shortLog
var peers: HashSet[PubSubPeer]
if g.parameters.floodPublish:
# With flood publishing enabled, the mesh is used when propagating messages from other peers,
# but a peer's own messages will always be published to all known peers in the topic.
for peer in g.gossipsub.getOrDefault(topic):
if peer.score >= g.parameters.publishThreshold:
trace "publish: including flood/high score peer", peer
peers.incl(peer)
# add always direct peers
peers.incl(g.subscribedDirectPeers.getOrDefault(topic))
peers.incl(g.explicit.getOrDefault(topic))
if topic in g.topics: # if we're subscribed use the mesh
peers.incl(g.mesh.getOrDefault(topic))
if g.parameters.floodPublish:
# With flood publishing enabled, the mesh is used when propagating messages from other peers,
# but a peer's own messages will always be published to all known peers in the topic, limited
# to the amount of peers we can send it to in one heartbeat
let maxPeersToFlood =
if g.parameters.bandwidthEstimatebps > 0:
let
bandwidth = (g.parameters.bandwidthEstimatebps) div 8 div 1000
# Divisions are to convert it to Bytes per ms TODO replace with bandwidth estimate
msToTransmit = max(data.len div bandwidth, 1)
max(
g.parameters.heartbeatInterval.milliseconds div msToTransmit,
g.parameters.dLow,
)
else:
int.high() # unlimited
for peer in g.gossipsub.getOrDefault(topic):
if peers.len >= maxPeersToFlood:
break
if peer.score >= g.parameters.publishThreshold:
trace "publish: including flood/high score peer", peer
peers.incl(peer)
elif peers.len < g.parameters.dLow:
if peers.len < g.parameters.dLow and g.parameters.floodPublish == false:
# not subscribed or bad mesh, send to fanout peers
# when flood-publishing, fanout won't help since all potential peers have
# already been added
g.replenishFanout(topic) # Make sure fanout is populated
# disable for floodPublish, since we already sent to every good peer
#
var fanoutPeers = g.fanout.getOrDefault(topic).toSeq()
if fanoutPeers.len == 0:
g.replenishFanout(topic)
fanoutPeers = g.fanout.getOrDefault(topic).toSeq()
g.rng.shuffle(fanoutPeers)
if fanoutPeers.len + peers.len > g.parameters.d:
fanoutPeers.setLen(g.parameters.d - peers.len)
for fanPeer in fanoutPeers:
peers.incl(fanPeer)
if peers.len > g.parameters.d:
break
if peers.len > g.parameters.d: break
# Attempting to publish counts as fanout send (even if the message
# ultimately is not sent)
# even if we couldn't publish,
# we still attempted to publish
# on the topic, so it makes sense
# to update the last topic publish
# time
g.lastFanoutPubSub[topic] = Moment.fromNow(g.parameters.fanoutTTL)
if peers.len == 0:
let topicPeers = g.gossipsub.getOrDefault(topic).toSeq()
debug "No peers for topic, skipping publish",
peersOnTopic = topicPeers.len,
connectedPeers = topicPeers.filterIt(it.connected).len,
topic
debug "No peers for topic, skipping publish", peersOnTopic = topicPeers.len,
connectedPeers = topicPeers.filterIt(it.connected).len,
topic
# skipping topic as our metrics find that heavy
libp2p_gossipsub_failed_publish.inc()
return 0
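# The bandwidth-based cap in the flood-publish branch above, worked
# through with hypothetical inputs:
let
  bandwidthEstimatebps = 10_000_000 # 10 Mbit/s, hypothetical
  bandwidth = bandwidthEstimatebps div 8 div 1000 # 1250 bytes per ms
  dataLen = 50_000 # 50 kB message
  msToTransmit = max(dataLen div bandwidth, 1) # 40 ms to send one copy
  heartbeatMs = 1000 # 1 s heartbeat interval
  dLow = 6
echo max(heartbeatMs div msToTransmit, dLow) # 25 peers this heartbeat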
@@ -761,42 +547,39 @@ method publish*(g: GossipSub, topic: string, data: seq[byte]): Future[int] {.asy
inc g.msgSeqno
Message.init(some(g.peerInfo), data, topic, some(g.msgSeqno), g.sign)
msgId = g.msgIdProvider(msg).valueOr:
trace "Error generating message id, skipping publish", error = error
trace "Error generating message id, skipping publish",
error = error
libp2p_gossipsub_failed_publish.inc()
return 0
logScope:
msgId = shortLog(msgId)
logScope: msgId = shortLog(msgId)
trace "Created new message", msg = shortLog(msg), peers = peers.len
if g.addSeen(g.salt(msgId)):
# If the message was received or published recently, don't re-publish it -
# this might happen when not using sequence ids and/or with a custom
# message id provider
if g.addSeen(msgId):
# custom msgid providers might cause this
trace "Dropping already-seen message"
return 0
g.mcache.put(msgId, msg)
g.broadcast(peers, RPCMsg(messages: @[msg]), isHighPriority = true)
g.broadcast(peers, RPCMsg(messages: @[msg]))
if g.knownTopics.contains(topic):
libp2p_pubsub_messages_published.inc(peers.len.int64, labelValues = [topic])
else:
libp2p_pubsub_messages_published.inc(peers.len.int64, labelValues = ["generic"])
trace "Published message to peers", peers = peers.len
trace "Published message to peers", peers=peers.len
return peers.len
proc maintainDirectPeer(g: GossipSub, id: PeerId, addrs: seq[MultiAddress]) {.async.} =
if id notin g.peers:
let peer = g.peers.getOrDefault(id)
if isNil(peer):
trace "Attempting to dial a direct peer", peer = id
if g.switch.isConnected(id):
warn "We are connected to a direct peer, but it isn't a GossipSub peer!", id
return
try:
await g.switch.connect(id, addrs, forceDial = true)
await g.switch.connect(id, addrs)
# populate the peer after it's connected
discard g.getOrCreatePeer(id, g.codecs)
except CancelledError as exc:
@@ -814,42 +597,33 @@ proc maintainDirectPeers(g: GossipSub) {.async.} =
for id, addrs in g.parameters.directPeers:
await g.addDirectPeer(id, addrs)
method start*(
g: GossipSub
): Future[void] {.async: (raises: [CancelledError], raw: true).} =
let fut = newFuture[void]()
fut.complete()
method start*(g: GossipSub) {.async.} =
trace "gossipsub start"
if not g.heartbeatFut.isNil:
warn "Starting gossipsub twice"
return fut
return
g.heartbeatFut = g.heartbeat()
g.scoringHeartbeatFut = g.scoringHeartbeat()
g.directPeersLoop = g.maintainDirectPeers()
g.started = true
fut
method stop*(g: GossipSub): Future[void] {.async: (raises: [], raw: true).} =
let fut = newFuture[void]()
fut.complete()
method stop*(g: GossipSub) {.async.} =
trace "gossipsub stop"
g.started = false
if g.heartbeatFut.isNil:
warn "Stopping gossipsub without starting it"
return fut
return
# stop heartbeat interval
g.directPeersLoop.cancel()
g.scoringHeartbeatFut.cancel()
g.heartbeatFut.cancel()
g.heartbeatFut = nil
fut
method initPubSub*(g: GossipSub) {.raises: [InitializationError].} =
method initPubSub*(g: GossipSub)
{.raises: [Defect, InitializationError].} =
procCall FloodSub(g).initPubSub()
if not g.parameters.explicit:
@@ -860,20 +634,7 @@ method initPubSub*(g: GossipSub) {.raises: [InitializationError].} =
raise newException(InitializationError, $validationRes.error)
# init the floodsub stuff here, we customize timedcache in gossip!
g.seen = TimedCache[SaltedId].init(g.parameters.seenTTL)
g.seen = TimedCache[MessageId].init(g.parameters.seenTTL)
# init gossip stuff
g.mcache = MCache.init(g.parameters.historyGossip, g.parameters.historyLength)
method getOrCreatePeer*(
g: GossipSub,
peerId: PeerId,
protosToDial: seq[string],
protoNegotiated: string = "",
): PubSubPeer =
let peer = procCall PubSub(g).getOrCreatePeer(peerId, protosToDial, protoNegotiated)
g.parameters.overheadRateLimit.withValue(overheadRateLimit):
peer.overheadRateLimitOpt =
Opt.some(TokenBucket.new(overheadRateLimit.bytes, overheadRateLimit.interval))
peer.maxNumElementsInNonPriorityQueue = g.parameters.maxNumElementsInNonPriorityQueue
return peer


@@ -7,62 +7,34 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import std/[tables, sequtils, sets, algorithm, deques]
import std/[tables, sequtils, sets, algorithm]
import chronos, chronicles, metrics
import "."/[types, scoring]
import ".."/[pubsubpeer, peertable, mcache, floodsub, pubsub]
import ".."/[pubsubpeer, peertable, timedcache, mcache, floodsub, pubsub]
import "../rpc"/[messages]
import
"../../.."/[
peerid,
multiaddress,
utility,
switch,
routing_record,
signed_envelope,
utils/heartbeat,
]
import "../../.."/[peerid, multiaddress, utility, switch, routing_record, signed_envelope, utils/heartbeat]
logScope:
topics = "libp2p gossipsub"
declareGauge(libp2p_gossipsub_cache_window_size, "the number of messages in the cache")
declareGauge(
libp2p_gossipsub_peers_per_topic_mesh,
"gossipsub peers per topic in mesh",
labels = ["topic"],
)
declareGauge(
libp2p_gossipsub_peers_per_topic_fanout,
"gossipsub peers per topic in fanout",
labels = ["topic"],
)
declareGauge(
libp2p_gossipsub_peers_per_topic_gossipsub,
"gossipsub peers per topic in gossipsub",
labels = ["topic"],
)
declareGauge(libp2p_gossipsub_peers_per_topic_mesh, "gossipsub peers per topic in mesh", labels = ["topic"])
declareGauge(libp2p_gossipsub_peers_per_topic_fanout, "gossipsub peers per topic in fanout", labels = ["topic"])
declareGauge(libp2p_gossipsub_peers_per_topic_gossipsub, "gossipsub peers per topic in gossipsub", labels = ["topic"])
declareGauge(libp2p_gossipsub_under_dout_topics, "number of topics below dout")
declareGauge(libp2p_gossipsub_no_peers_topics, "number of topics in mesh with no peers")
declareGauge(
libp2p_gossipsub_low_peers_topics,
"number of topics in mesh with at least one but below dlow peers",
)
declareGauge(
libp2p_gossipsub_healthy_peers_topics,
"number of topics in mesh with at least dlow peers (but below dhigh)",
)
declareCounter(
libp2p_gossipsub_above_dhigh_condition,
"number of above dhigh pruning branches ran",
labels = ["topic"],
)
declareGauge(libp2p_gossipsub_received_iwants, "received iwants", labels = ["kind"])
declareGauge(libp2p_gossipsub_low_peers_topics, "number of topics in mesh with at least one but below dlow peers")
declareGauge(libp2p_gossipsub_healthy_peers_topics, "number of topics in mesh with at least dlow peers (but below dhigh)")
declareCounter(libp2p_gossipsub_above_dhigh_condition, "number of above dhigh pruning branches run", labels = ["topic"])
declareSummary(libp2p_gossipsub_mcache_hit, "ratio of successful IWANT message cache lookups")
proc grafted*(g: GossipSub, p: PubSubPeer, topic: string) =
g.withPeerStats(p.peerId) do(stats: var PeerStats):
proc grafted*(g: GossipSub, p: PubSubPeer, topic: string) {.raises: [Defect].} =
g.withPeerStats(p.peerId) do (stats: var PeerStats):
var info = stats.topicInfos.getOrDefault(topic)
info.graftTime = Moment.now()
info.meshTime = 0.seconds
@@ -71,55 +43,57 @@ proc grafted*(g: GossipSub, p: PubSubPeer, topic: string) =
stats.topicInfos[topic] = info
trace "grafted", peer = p, topic
trace "grafted", peer=p, topic
proc pruned*(
g: GossipSub,
p: PubSubPeer,
topic: string,
setBackoff: bool = true,
backoff = none(Duration),
) =
proc pruned*(g: GossipSub,
p: PubSubPeer,
topic: string,
setBackoff: bool = true,
backoff = none(Duration)) {.raises: [Defect].} =
if setBackoff:
let
backoffDuration = backoff.get(g.parameters.pruneBackoff)
backoffDuration =
if isSome(backoff): backoff.get()
else: g.parameters.pruneBackoff
backoffMoment = Moment.fromNow(backoffDuration)
g.backingOff.mgetOrPut(topic, initTable[PeerId, Moment]())[p.peerId] = backoffMoment
g.backingOff
.mgetOrPut(topic, initTable[PeerId, Moment]())[p.peerId] = backoffMoment
g.peerStats.withValue(p.peerId, stats):
stats.topicInfos.withValue(topic, info):
g.topicParams.withValue(topic, topicParams):
# penalize a peer that delivered no message
let threshold = topicParams[].meshMessageDeliveriesThreshold
if info[].inMesh and info[].meshMessageDeliveriesActive and
if info[].inMesh and
info[].meshMessageDeliveriesActive and
info[].meshMessageDeliveries < threshold:
let deficit = threshold - info.meshMessageDeliveries
info[].meshFailurePenalty += deficit * deficit
info.inMesh = false
trace "pruned", peer = p, topic
trace "pruned", peer=p, topic
proc handleBackingOff*(t: var BackoffTable, topic: string) =
proc handleBackingOff*(t: var BackoffTable, topic: string) {.raises: [Defect].} =
let now = Moment.now()
var expired = toSeq(t.getOrDefault(topic).pairs())
expired.keepIf do(pair: tuple[peer: PeerId, expire: Moment]) -> bool:
expired.keepIf do (pair: tuple[peer: PeerId, expire: Moment]) -> bool:
now >= pair.expire
for (peer, _) in expired:
t.withValue(topic, v):
v[].del(peer)
proc peerExchangeList*(g: GossipSub, topic: string): seq[PeerInfoMsg] =
proc peerExchangeList*(g: GossipSub, topic: string): seq[PeerInfoMsg] {.raises: [Defect].} =
if not g.parameters.enablePX:
return @[]
var peers = g.gossipsub.getOrDefault(topic, initHashSet[PubSubPeer]()).toSeq()
peers.keepIf do(x: PubSubPeer) -> bool:
x.score >= 0.0
peers.keepIf do (x: PubSubPeer) -> bool:
x.score >= 0.0
# by spec, larger then Dhi, but let's put some hard caps
peers.setLen(min(peers.len, g.parameters.dHigh * 2))
let sprBook = g.switch.peerStore[SPRBook]
peers.map do(x: PubSubPeer) -> PeerInfoMsg:
peers.map do (x: PubSubPeer) -> PeerInfoMsg:
PeerInfoMsg(
peerId: x.peerId,
signedPeerRecord:
@@ -127,65 +101,52 @@ proc peerExchangeList*(g: GossipSub, topic: string): seq[PeerInfoMsg] =
sprBook[x.peerId].encode().get(default(seq[byte]))
else:
default(seq[byte])
,
)
proc handleGraft*(
g: GossipSub, peer: PubSubPeer, grafts: seq[ControlGraft]
): seq[ControlPrune] =
var prunes: seq[ControlPrune]
for graft in grafts:
let topic = graft.topicID
trace "peer grafted topicID", peer, topic
# It is an error to GRAFT on a direct peer
if peer.peerId in g.parameters.directPeers:
# receiving a graft from a direct peer should yield a more prominent warning (protocol violation)
# we are trusting the direct peer not to abuse this
warn "a direct peer attempted to graft us, peering agreements should be reciprocal",
peer, topic
# and such an attempt should be logged and rejected with a PRUNE
prunes.add(
ControlPrune(
topicID: topic,
peers: @[],
# omitting heavy computation here as the remote did something illegal
backoff: g.parameters.pruneBackoff.seconds.uint64,
)
)
let backoff = Moment.fromNow(g.parameters.pruneBackoff)
proc handleGraft*(g: GossipSub,
peer: PubSubPeer,
grafts: seq[ControlGraft]): seq[ControlPrune] = # {.raises: [Defect].} TODO chronicles exception on windows
var prunes: seq[ControlPrune]
for graft in grafts:
let topic = graft.topicId
trace "peer grafted topic", peer, topic
g.backingOff.mgetOrPut(topic, initTable[PeerId, Moment]())[peer.peerId] = backoff
# It is an error to GRAFT on an explicit peer
if peer.peerId in g.parameters.directPeers:
# receiving a graft from a direct peer should yield a more prominent warning (protocol violation)
warn "an explicit peer attempted to graft us, peering agreements should be reciprocal",
peer, topic
# and such an attempt should be logged and rejected with a PRUNE
prunes.add(ControlPrune(
topicID: topic,
peers: @[], # omitting heavy computation here as the remote did something illegal
backoff: g.parameters.pruneBackoff.seconds.uint64))
let backoff = Moment.fromNow(g.parameters.pruneBackoff)
g.backingOff
.mgetOrPut(topic, initTable[PeerId, Moment]())[peer.peerId] = backoff
peer.behaviourPenalty += 0.1
continue
if g.mesh.hasPeer(topic, peer):
trace "peer already in mesh", peer, topic
continue
# Check backingOff
# Ignore BackoffSlackTime here, since this is only for outbound activity,
# and subtract a second time to avoid race conditions
# (peers may wait to graft us at the exact instant they're allowed to)
if g.backingOff.getOrDefault(topic).getOrDefault(peer.peerId) -
(BackoffSlackTime * 2).seconds > Moment.now():
if g.backingOff
.getOrDefault(topic)
.getOrDefault(peer.peerId) - (BackoffSlackTime * 2).seconds > Moment.now():
debug "a backingOff peer attempted to graft us", peer, topic
# and such an attempt should be logged and rejected with a PRUNE
prunes.add(
ControlPrune(
topicID: topic,
peers: @[],
# omitting heavy computation here as the remote did something illegal
backoff: g.parameters.pruneBackoff.seconds.uint64,
)
)
prunes.add(ControlPrune(
topicID: topic,
peers: @[], # omitting heavy computation here as the remote did something illegal
backoff: g.parameters.pruneBackoff.seconds.uint64))
let backoff = Moment.fromNow(g.parameters.pruneBackoff)
g.backingOff.mgetOrPut(topic, initTable[PeerId, Moment]())[peer.peerId] = backoff
g.backingOff
.mgetOrPut(topic, initTable[PeerId, Moment]())[peer.peerId] = backoff
peer.behaviourPenalty += 0.1
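# Sketch of the double-slack backoff check above, with hypothetical
# times: the graft is rejected only while the stored expiry, minus
# twice the slack, still lies in the future.
import chronos

let
  now = Moment.now()
  expiry = now + 3.seconds # hypothetical stored backoff expiry
  slackSeconds = 2 # BackoffSlackTime from this module
echo expiry - (slackSeconds * 2).seconds > now # true: still backing off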
@@ -199,8 +160,7 @@ proc handleGraft*(
# If they send us a graft before they send us a subscribe, what should
# we do? For now, we add them to mesh but don't add them to gossipsub.
if topic in g.topics:
if g.mesh.peers(topic) < g.parameters.dHigh or
(peer.outbound and g.mesh.outboundPeers(topic) < g.parameters.dOut):
if g.mesh.peers(topic) < g.parameters.dHigh or peer.outbound:
# In the spec, there's no mention of DHi here, but implicitly, a
# peer will be removed from the mesh on next rebalance, so we don't want
# this peer to push someone else out
@@ -212,148 +172,140 @@ proc handleGraft*(
else:
trace "pruning grafting peer, mesh full",
peer, topic, score = peer.score, mesh = g.mesh.peers(topic)
prunes.add(
ControlPrune(
topicID: topic,
peers: g.peerExchangeList(topic),
backoff: g.parameters.pruneBackoff.seconds.uint64,
)
)
let backoff = Moment.fromNow(g.parameters.pruneBackoff)
g.backingOff.mgetOrPut(topic, initTable[PeerId, Moment]())[peer.peerId] =
backoff
prunes.add(ControlPrune(
topicID: topic,
peers: g.peerExchangeList(topic),
backoff: g.parameters.pruneBackoff.seconds.uint64))
else:
trace "peer grafting topic we're not interested in", peer, topic
# gossip 1.1, we do not send a control message prune anymore
return prunes
proc getPeers(
prune: ControlPrune, peer: PubSubPeer
): seq[(PeerId, Option[PeerRecord])] =
proc getPeers(prune: ControlPrune, peer: PubSubPeer): seq[(PeerId, Option[PeerRecord])] =
var routingRecords: seq[(PeerId, Option[PeerRecord])]
for record in prune.peers:
var peerRecord = none(PeerRecord)
if record.signedPeerRecord.len > 0:
SignedPeerRecord.decode(record.signedPeerRecord).toOpt().withValue(spr):
if record.peerId != spr.data.peerId:
trace "peer sent envelope with wrong public key", peer
else:
peerRecord = some(spr.data)
let peerRecord =
if record.signedPeerRecord.len == 0:
none(PeerRecord)
else:
trace "peer sent invalid SPR", peer
let signedRecord = SignedPeerRecord.decode(record.signedPeerRecord)
if signedRecord.isErr:
trace "peer sent invalid SPR", peer, error=signedRecord.error
none(PeerRecord)
else:
if record.peerId != signedRecord.get().data.peerId:
trace "peer sent envelope with wrong public key", peer
none(PeerRecord)
else:
some(signedRecord.get().data)
routingRecords.add((record.peerId, peerRecord))
routingRecords
proc handlePrune*(g: GossipSub, peer: PubSubPeer, prunes: seq[ControlPrune]) =
for prune in prunes:
let topic = prune.topicID
trace "peer pruned topicID", peer, topic
proc handlePrune*(g: GossipSub, peer: PubSubPeer, prunes: seq[ControlPrune]) {.raises: [Defect].} =
for prune in prunes:
let topic = prune.topicId
trace "peer pruned topic", peer, topic
# add peer backoff
if prune.backoff > 0:
let
# avoid overflows and clamp to reasonable value
backoffSeconds =
clamp(prune.backoff + BackoffSlackTime, 0'u64, 1.days.seconds.uint64)
backoffSeconds = clamp(
prune.backoff + BackoffSlackTime,
0'u64,
1.days.seconds.uint64
)
backoff = Moment.fromNow(backoffSeconds.int64.seconds)
current = g.backingOff.getOrDefault(topic).getOrDefault(peer.peerId)
if backoff > current:
g.backingOff.mgetOrPut(topic, initTable[PeerId, Moment]())[peer.peerId] =
backoff
g.backingOff
.mgetOrPut(topic, initTable[PeerId, Moment]())[peer.peerId] = backoff
trace "pruning rpc received peer", peer, score = peer.score
g.pruned(peer, topic, setBackoff = false)
g.mesh.removePeer(topic, peer)
if peer.score > g.parameters.gossipThreshold and prune.peers.len > 0 and
g.routingRecordsHandler.len > 0:
g.routingRecordsHandler.len > 0:
let routingRecords = prune.getPeers(peer)
for handler in g.routingRecordsHandler:
handler(peer.peerId, topic, routingRecords)
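# Worked example of the backoff clamp above: a peer advertising an
# absurd backoff is capped at one day (values hypothetical).
import chronos

const BackoffSlackTime = 2'u64 # seconds; declared as a plain int in this module
let requested = 999_999_999'u64 # absurd remote backoff value
echo clamp(requested + BackoffSlackTime, 0'u64, 1.days.seconds.uint64) # 86400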
proc handleIHave*(
g: GossipSub, peer: PubSubPeer, ihaves: seq[ControlIHave]
): ControlIWant =
proc handleIHave*(g: GossipSub,
peer: PubSubPeer,
ihaves: seq[ControlIHave]): ControlIWant {.raises: [Defect].} =
var res: ControlIWant
if peer.score < g.parameters.gossipThreshold:
trace "ihave: ignoring low score peer", peer, score = peer.score
elif peer.iHaveBudget <= 0:
trace "ihave: ignoring out of budget peer", peer, score = peer.score
else:
for ihave in ihaves:
trace "peer sent ihave", peer, topicID = ihave.topicID, msgs = ihave.messageIDs
if ihave.topicID in g.topics:
for msgId in ihave.messageIDs:
if not g.hasSeen(g.salt(msgId)):
if peer.iHaveBudget <= 0:
break
elif msgId notin res.messageIDs:
res.messageIDs.add(msgId)
# TODO review deduplicate algorithm
# * https://github.com/nim-lang/Nim/blob/5f46474555ee93306cce55342e81130c1da79a42/lib/pure/collections/sequtils.nim#L184
# * it's probably not efficient and might give preference to the first dupe
let deIhaves = ihaves.deduplicate()
for ihave in deIhaves:
trace "peer sent ihave",
peer, topic = ihave.topicId, msgs = ihave.messageIds
if ihave.topicId in g.mesh:
# also avoid duplicates here!
let deIhavesMsgs = ihave.messageIds.deduplicate()
for msgId in deIhavesMsgs:
if not g.hasSeen(msgId):
if peer.iHaveBudget > 0:
res.messageIds.add(msgId)
dec peer.iHaveBudget
trace "requested message via ihave", messageID = msgId
trace "requested message via ihave", messageID=msgId
else:
break
# shuffling res.messageIDs before sending it out to increase the likelihood
# of getting an answer if the peer truncates the list due to internal size restrictions.
g.rng.shuffle(res.messageIDs)
g.rng.shuffle(res.messageIds)
return res
proc handleIDontWant*(g: GossipSub, peer: PubSubPeer, iDontWants: seq[ControlIWant]) =
for dontWant in iDontWants:
for messageId in dontWant.messageIDs:
if peer.iDontWants[^1].len > 1000:
break
peer.iDontWants[^1].incl(g.salt(messageId))
proc handleIWant*(
g: GossipSub, peer: PubSubPeer, iwants: seq[ControlIWant]
): seq[Message] =
var
messages: seq[Message]
invalidRequests = 0
proc handleIWant*(g: GossipSub,
peer: PubSubPeer,
iwants: seq[ControlIWant]): seq[Message] {.raises: [Defect].} =
var messages: seq[Message]
if peer.score < g.parameters.gossipThreshold:
trace "iwant: ignoring low score peer", peer, score = peer.score
elif peer.iWantBudget <= 0:
trace "iwant: ignoring out of budget peer", peer, score = peer.score
else:
for iwant in iwants:
for mid in iwant.messageIDs:
let deIwants = iwants.deduplicate()
for iwant in deIwants:
let deIwantsMsgs = iwant.messageIds.deduplicate()
for mid in deIwantsMsgs:
trace "peer sent iwant", peer, messageID = mid
# canAskIWant will only return true once for a specific message
if not peer.canAskIWant(mid):
libp2p_gossipsub_received_iwants.inc(1, labelValues = ["notsent"])
invalidRequests.inc()
if invalidRequests > 20:
libp2p_gossipsub_received_iwants.inc(1, labelValues = ["skipped"])
return messages
continue
let msg = g.mcache.get(mid).valueOr:
libp2p_gossipsub_received_iwants.inc(1, labelValues = ["unknown"])
continue
libp2p_gossipsub_received_iwants.inc(1, labelValues = ["correct"])
messages.add(msg)
let msg = g.mcache.get(mid)
if msg.isSome:
libp2p_gossipsub_mcache_hit.observe(1)
# avoid spam
if peer.iWantBudget > 0:
messages.add(msg.get())
dec peer.iWantBudget
else:
break
else:
libp2p_gossipsub_mcache_hit.observe(0)
return messages
proc commitMetrics(metrics: var MeshMetrics) =
proc commitMetrics(metrics: var MeshMetrics) {.raises: [Defect].} =
libp2p_gossipsub_low_peers_topics.set(metrics.lowPeersTopics)
libp2p_gossipsub_no_peers_topics.set(metrics.noPeersTopics)
libp2p_gossipsub_under_dout_topics.set(metrics.underDoutTopics)
libp2p_gossipsub_healthy_peers_topics.set(metrics.healthyPeersTopics)
libp2p_gossipsub_peers_per_topic_gossipsub.set(
metrics.otherPeersPerTopicGossipsub, labelValues = ["other"]
)
libp2p_gossipsub_peers_per_topic_fanout.set(
metrics.otherPeersPerTopicFanout, labelValues = ["other"]
)
libp2p_gossipsub_peers_per_topic_mesh.set(
metrics.otherPeersPerTopicMesh, labelValues = ["other"]
)
libp2p_gossipsub_peers_per_topic_gossipsub.set(metrics.otherPeersPerTopicGossipsub, labelValues = ["other"])
libp2p_gossipsub_peers_per_topic_fanout.set(metrics.otherPeersPerTopicFanout, labelValues = ["other"])
libp2p_gossipsub_peers_per_topic_mesh.set(metrics.otherPeersPerTopicMesh, labelValues = ["other"])
proc rebalanceMesh*(g: GossipSub, topic: string, metrics: ptr MeshMetrics = nil) =
proc rebalanceMesh*(g: GossipSub, topic: string, metrics: ptr MeshMetrics = nil) {.raises: [Defect].} =
logScope:
topic
mesh = g.mesh.peers(topic)
@@ -366,25 +318,25 @@ proc rebalanceMesh*(g: GossipSub, topic: string, metrics: ptr MeshMetrics = nil)
var
prunes, grafts: seq[PubSubPeer]
npeers = g.mesh.peers(topic)
nOutPeers = g.mesh.outboundPeers(topic)
defaultMesh: HashSet[PubSubPeer]
backingOff = g.backingOff.getOrDefault(topic)
if npeers < g.parameters.dLow:
if npeers < g.parameters.dLow:
trace "replenishing mesh", peers = npeers
# replenish the mesh if we're below Dlo
var
candidates: seq[PubSubPeer]
currentMesh = addr defaultMesh
g.mesh.withValue(topic, v):
currentMesh = v
g.mesh.withValue(topic, v): currentMesh = v
g.gossipsub.withValue(topic, peerList):
for it in peerList[]:
if it.connected and
if
it.connected and
# avoid negative score peers
it.score >= 0.0 and it notin currentMesh[] and
# don't pick direct peers
it.score >= 0.0 and
it notin currentMesh[] and
# don't pick explicit peers
it.peerId notin g.parameters.directPeers and
# and avoid peers we are backing off
it.peerId notin backingOff:
@@ -407,22 +359,24 @@ proc rebalanceMesh*(g: GossipSub, topic: string, metrics: ptr MeshMetrics = nil)
g.grafted(peer, topic)
g.fanout.removePeer(topic, peer)
grafts &= peer
elif nOutPeers < g.parameters.dOut:
else:
trace "replenishing mesh outbound quota", peers = g.mesh.peers(topic)
var
candidates: seq[PubSubPeer]
currentMesh = addr defaultMesh
g.mesh.withValue(topic, v):
currentMesh = v
g.mesh.withValue(topic, v): currentMesh = v
g.gossipsub.withValue(topic, peerList):
for it in peerList[]:
if it.connected and
# get only outbound ones
it.outbound and it notin currentMesh[] and
if
it.connected and
# get only outbound ones
it.outbound and
it notin currentMesh[] and
# avoid negative score peers
it.score >= 0.0 and
# don't pick direct peers
# don't pick explicit peers
it.peerId notin g.parameters.directPeers and
# and avoid peers we are backing off
it.peerId notin backingOff:
@@ -434,8 +388,8 @@ proc rebalanceMesh*(g: GossipSub, topic: string, metrics: ptr MeshMetrics = nil)
# sort peers by score, high score first, we are grafting
candidates.sort(byScore, SortOrder.Descending)
# Graft outgoing peers so we reach a count of dOut
candidates.setLen(min(candidates.len, g.parameters.dOut - nOutPeers))
# Graft peers so we reach a count of D
candidates.setLen(min(candidates.len, g.parameters.dOut))
trace "grafting outbound peers", topic, peers = candidates.len
@@ -445,6 +399,7 @@ proc rebalanceMesh*(g: GossipSub, topic: string, metrics: ptr MeshMetrics = nil)
g.fanout.removePeer(topic, peer)
grafts &= peer
# get again npeers after possible grafts
npeers = g.mesh.peers(topic)
if npeers > g.parameters.dHigh:
@@ -455,15 +410,9 @@ proc rebalanceMesh*(g: GossipSub, topic: string, metrics: ptr MeshMetrics = nil)
libp2p_gossipsub_above_dhigh_condition.inc(labelValues = ["other"])
# prune peers if we've gone over Dhi
prunes = toSeq(
try:
g.mesh[topic]
except KeyError:
raiseAssert "have peers"
)
prunes = toSeq(try: g.mesh[topic] except KeyError: raiseAssert "have peers")
# avoid pruning peers we are currently grafting in this heartbeat
prunes.keepIf do(x: PubSubPeer) -> bool:
x notin grafts
prunes.keepIf do (x: PubSubPeer) -> bool: x notin grafts
# shuffle anyway, score might be not used
g.rng.shuffle(prunes)
@@ -511,12 +460,7 @@ proc rebalanceMesh*(g: GossipSub, topic: string, metrics: ptr MeshMetrics = nil)
# opportunistic grafting, by spec mesh should not be empty...
if g.mesh.peers(topic) > 1:
var peers = toSeq(
try:
g.mesh[topic]
except KeyError:
raiseAssert "have peers"
)
var peers = toSeq(try: g.mesh[topic] except KeyError: raiseAssert "have peers")
# grafting so high score has priority
peers.sort(byScore, SortOrder.Descending)
let medianIdx = peers.len div 2
@@ -527,13 +471,14 @@ proc rebalanceMesh*(g: GossipSub, topic: string, metrics: ptr MeshMetrics = nil)
var
avail: seq[PubSubPeer]
currentMesh = addr defaultMesh
g.mesh.withValue(topic, v):
currentMesh = v
g.mesh.withValue(topic, v): currentMesh = v
g.gossipsub.withValue(topic, peerList):
for it in peerList[]:
if it.score >= median.score and # avoid negative score peers
it notin currentMesh[] and
# don't pick direct peers
if
# avoid negative score peers
it.score >= median.score and
it notin currentMesh[] and
# don't pick explicit peers
it.peerId notin g.parameters.directPeers and
# and avoid peers we are backing off
it.peerId notin backingOff:
@@ -559,23 +504,17 @@ proc rebalanceMesh*(g: GossipSub, topic: string, metrics: ptr MeshMetrics = nil)
inc metrics[].healthyPeersTopics
var meshPeers = toSeq(g.mesh.getOrDefault(topic, initHashSet[PubSubPeer]()))
meshPeers.keepIf do(x: PubSubPeer) -> bool:
x.outbound
meshPeers.keepIf do (x: PubSubPeer) -> bool: x.outbound
if meshPeers.len < g.parameters.dOut:
inc metrics[].underDoutTopics
if g.knownTopics.contains(topic):
libp2p_gossipsub_peers_per_topic_gossipsub.set(
g.gossipsub.peers(topic).int64, labelValues = [topic]
)
libp2p_gossipsub_peers_per_topic_fanout.set(
g.fanout.peers(topic).int64, labelValues = [topic]
)
libp2p_gossipsub_peers_per_topic_mesh.set(
g.mesh.peers(topic).int64, labelValues = [topic]
)
libp2p_gossipsub_peers_per_topic_gossipsub
.set(g.gossipsub.peers(topic).int64, labelValues = [topic])
libp2p_gossipsub_peers_per_topic_fanout
.set(g.fanout.peers(topic).int64, labelValues = [topic])
libp2p_gossipsub_peers_per_topic_mesh
.set(g.mesh.peers(topic).int64, labelValues = [topic])
else:
metrics[].otherPeersPerTopicGossipsub += g.gossipsub.peers(topic).int64
metrics[].otherPeersPerTopicFanout += g.fanout.peers(topic).int64
@@ -585,27 +524,17 @@ proc rebalanceMesh*(g: GossipSub, topic: string, metrics: ptr MeshMetrics = nil)
# Send changes to peers after table updates to avoid stale state
if grafts.len > 0:
let graft =
RPCMsg(control: some(ControlMessage(graft: @[ControlGraft(topicID: topic)])))
g.broadcast(grafts, graft, isHighPriority = true)
let graft = RPCMsg(control: some(ControlMessage(graft: @[ControlGraft(topicID: topic)])))
g.broadcast(grafts, graft)
if prunes.len > 0:
let prune = RPCMsg(
control: some(
ControlMessage(
prune:
@[
ControlPrune(
topicID: topic,
peers: g.peerExchangeList(topic),
backoff: g.parameters.pruneBackoff.seconds.uint64,
)
]
)
)
)
g.broadcast(prunes, prune, isHighPriority = true)
let prune = RPCMsg(control: some(ControlMessage(
prune: @[ControlPrune(
topicID: topic,
peers: g.peerExchangeList(topic),
backoff: g.parameters.pruneBackoff.seconds.uint64)])))
g.broadcast(prunes, prune)
proc dropFanoutPeers*(g: GossipSub) =
proc dropFanoutPeers*(g: GossipSub) {.raises: [Defect].} =
# drop peers that we haven't published to in
# GossipSubFanoutTTL seconds
let now = Moment.now()
@@ -618,25 +547,23 @@ proc dropFanoutPeers*(g: GossipSub) =
for topic in drops:
g.lastFanoutPubSub.del topic
proc replenishFanout*(g: GossipSub, topic: string) =
proc replenishFanout*(g: GossipSub, topic: string) {.raises: [Defect].} =
## get fanout peers for a topic
logScope:
topic
logScope: topic
trace "about to replenish fanout"
let currentMesh = g.mesh.getOrDefault(topic)
if g.fanout.peers(topic) < g.parameters.dLow:
let currentMesh = g.mesh.getOrDefault(topic)
trace "replenishing fanout", peers = g.fanout.peers(topic)
for peer in g.gossipsub.getOrDefault(topic):
if peer in currentMesh:
continue
if peer in currentMesh: continue
if g.fanout.addPeer(topic, peer):
if g.fanout.peers(topic) == g.parameters.d:
break
trace "fanout replenished with peers", peers = g.fanout.peers(topic)
proc getGossipPeers*(g: GossipSub): Table[PubSubPeer, ControlMessage] =
proc getGossipPeers*(g: GossipSub): Table[PubSubPeer, ControlMessage] {.raises: [Defect].} =
## gossip iHave messages to peers
##
@@ -644,14 +571,14 @@ proc getGossipPeers*(g: GossipSub): Table[PubSubPeer, ControlMessage] =
var control: Table[PubSubPeer, ControlMessage]
let topics = toHashSet(toSeq(g.mesh.keys)) + toHashSet(toSeq(g.fanout.keys))
trace "getting gossip peers (iHave)", ntopics = topics.len
trace "getting gossip peers (iHave)", ntopics=topics.len
for topic in topics:
if topic notin g.gossipsub:
trace "topic not in gossip array, skipping", topic = topic
trace "topic not in gossip array, skipping", topicID = topic
continue
let mids = g.mcache.window(topic)
if not (mids.len > 0):
if not(mids.len > 0):
trace "no messages to emit"
continue
@@ -659,7 +586,7 @@ proc getGossipPeers*(g: GossipSub): Table[PubSubPeer, ControlMessage] =
cacheWindowSize += midsSeq.len
trace "got messages to emit", size = midsSeq.len
trace "got messages to emit", size=midsSeq.len
# not in spec
# similar to rust: https://github.com/sigp/rust-libp2p/blob/f53d02bc873fef2bf52cd31e3d5ce366a41d8a8c/protocols/gossipsub/src/behaviour.rs#L2101
@@ -675,14 +602,15 @@ proc getGossipPeers*(g: GossipSub): Table[PubSubPeer, ControlMessage] =
gossipPeers = mesh + fanout
var allPeers = toSeq(g.gossipsub.getOrDefault(topic))
allPeers.keepIf do(x: PubSubPeer) -> bool:
x.peerId notin g.parameters.directPeers and x notin gossipPeers and
x.score >= g.parameters.gossipThreshold
allPeers.keepIf do (x: PubSubPeer) -> bool:
x.peerId notin g.parameters.directPeers and
x notin gossipPeers and
x.score >= g.parameters.gossipThreshold
# https://github.com/libp2p/specs/blob/98c5aa9421703fc31b0833ad8860a55db15be063/pubsub/gossipsub/gossipsub-v1.1.md#adaptive-gossip-dissemination
let
factor = (g.parameters.gossipFactor.float * allPeers.len.float).int
target = max(g.parameters.dLazy, factor)
var target = g.parameters.dLazy
let factor = (g.parameters.gossipFactor.float * allPeers.len.float).int
if factor > target:
target = min(factor, allPeers.len)
if target < allPeers.len:
g.rng.shuffle(allPeers)
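# The adaptive dissemination above in one line, with hypothetical
# dLazy = 6, gossipFactor = 0.25 and 40 candidates: gossip targets
# max(6, 0.25 * 40) = 10 peers; small pools fall back to dLazy.
let
  dLazy = 6
  gossipFactor = 0.25
  candidates = 40
echo max(dLazy, int(gossipFactor * candidates.float)) # 10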
@@ -690,85 +618,70 @@ proc getGossipPeers*(g: GossipSub): Table[PubSubPeer, ControlMessage] =
for peer in allPeers:
control.mgetOrPut(peer, ControlMessage()).ihave.add(ihave)
for msgId in ihave.messageIDs:
peer.sentIHaves[^1].incl(msgId)
libp2p_gossipsub_cache_window_size.set(cacheWindowSize.int64)
return control
proc onHeartbeat(g: GossipSub) =
# reset IWANT budget
# reset IHAVE cap
block:
for peer in g.peers.values:
peer.sentIHaves.addFirst(default(HashSet[MessageId]))
if peer.sentIHaves.len > g.parameters.historyLength:
discard peer.sentIHaves.popLast()
peer.iDontWants.addFirst(default(HashSet[SaltedId]))
if peer.iDontWants.len > g.parameters.historyLength:
discard peer.iDontWants.popLast()
peer.iHaveBudget = IHavePeerBudget
peer.pingBudget = PingsPeerBudget
var meshMetrics = MeshMetrics()
for t in toSeq(g.topics.keys):
# remove expired backoffs
proc onHeartbeat(g: GossipSub) {.raises: [Defect].} =
# reset IWANT budget
# reset IHAVE cap
block:
handleBackingOff(g.backingOff, t)
for peer in g.peers.values:
peer.iWantBudget = IWantPeerBudget
peer.iHaveBudget = IHavePeerBudget
# prune every negative score peer
# do this before relance
# in order to avoid grafted -> pruned in the same cycle
let meshPeers = g.mesh.getOrDefault(t)
var prunes: seq[PubSubPeer]
for peer in meshPeers:
if peer.score < 0.0:
trace "pruning negative score peer", peer, score = peer.score
g.pruned(peer, t)
g.mesh.removePeer(t, peer)
prunes &= peer
if prunes.len > 0:
let prune = RPCMsg(
control: some(
ControlMessage(
prune:
@[
ControlPrune(
topicID: t,
peers: g.peerExchangeList(t),
backoff: g.parameters.pruneBackoff.seconds.uint64,
)
]
)
)
)
g.broadcast(prunes, prune, isHighPriority = true)
var meshMetrics = MeshMetrics()
# pass by ptr in order to both signal we want to update metrics
# and as well update the struct for each topic during this iteration
g.rebalanceMesh(t, addr meshMetrics)
for t in toSeq(g.topics.keys):
# remove expired backoffs
block:
handleBackingOff(g.backingOff, t)
commitMetrics(meshMetrics)
# prune every negative score peer
# do this before rebalance
# in order to avoid grafted -> pruned in the same cycle
let meshPeers = g.mesh.getOrDefault(t)
var prunes: seq[PubSubPeer]
for peer in meshPeers:
if peer.score < 0.0:
trace "pruning negative score peer", peer, score = peer.score
g.pruned(peer, t)
g.mesh.removePeer(t, peer)
prunes &= peer
if prunes.len > 0:
let prune = RPCMsg(control: some(ControlMessage(
prune: @[ControlPrune(
topicID: t,
peers: g.peerExchangeList(t),
backoff: g.parameters.pruneBackoff.seconds.uint64)])))
g.broadcast(prunes, prune)
g.dropFanoutPeers()
# pass by ptr in order to both signal we want to update metrics
# and as well update the struct for each topic during this iteration
g.rebalanceMesh(t, addr meshMetrics)
# replenish known topics to the fanout
for t in toSeq(g.fanout.keys):
g.replenishFanout(t)
commitMetrics(meshMetrics)
let peers = g.getGossipPeers()
for peer, control in peers:
# only ihave from here
for ihave in control.ihave:
if g.knownTopics.contains(ihave.topicID):
libp2p_pubsub_broadcast_ihave.inc(labelValues = [ihave.topicID])
else:
libp2p_pubsub_broadcast_ihave.inc(labelValues = ["generic"])
g.send(peer, RPCMsg(control: some(control)), isHighPriority = true)
g.dropFanoutPeers()
g.mcache.shift() # shift the cache
# replenish known topics to the fanout
for t in toSeq(g.fanout.keys):
g.replenishFanout(t)
let peers = g.getGossipPeers()
for peer, control in peers:
# only ihave from here
for ihave in control.ihave:
if g.knownTopics.contains(ihave.topicId):
libp2p_pubsub_broadcast_ihave.inc(labelValues = [ihave.topicId])
else:
libp2p_pubsub_broadcast_ihave.inc(labelValues = ["generic"])
g.send(peer, RPCMsg(control: some(control)))
g.mcache.shift() # shift the cache
# {.pop.} # raises [Defect]
proc heartbeat*(g: GossipSub) {.async.} =
heartbeat "GossipSub", g.parameters.heartbeatInterval:


@@ -7,70 +7,29 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import std/[tables, sets]
import std/[tables, sets, options]
import chronos, chronicles, metrics
import chronos/ratelimit
import "."/[types]
import ".."/[pubsubpeer]
import ../rpc/messages
import "../../.."/[peerid, multiaddress, switch, utils/heartbeat]
import ../pubsub
import "../../.."/[peerid, multiaddress, utility, switch, utils/heartbeat]
logScope:
topics = "libp2p gossipsub"
declareGauge(
libp2p_gossipsub_peers_scores,
"the scores of the peers in gossipsub",
labels = ["agent"],
)
declareCounter(
libp2p_gossipsub_bad_score_disconnection,
"the number of peers disconnected by gossipsub",
labels = ["agent"],
)
declareGauge(
libp2p_gossipsub_peers_score_firstMessageDeliveries,
"Detailed gossipsub scoring metric",
labels = ["agent"],
)
declareGauge(
libp2p_gossipsub_peers_score_meshMessageDeliveries,
"Detailed gossipsub scoring metric",
labels = ["agent"],
)
declareGauge(
libp2p_gossipsub_peers_score_meshFailurePenalty,
"Detailed gossipsub scoring metric",
labels = ["agent"],
)
declareGauge(
libp2p_gossipsub_peers_score_invalidMessageDeliveries,
"Detailed gossipsub scoring metric",
labels = ["agent"],
)
declareGauge(
libp2p_gossipsub_peers_score_appScore,
"Detailed gossipsub scoring metric",
labels = ["agent"],
)
declareGauge(
libp2p_gossipsub_peers_score_behaviourPenalty,
"Detailed gossipsub scoring metric",
labels = ["agent"],
)
declareGauge(
libp2p_gossipsub_peers_score_colocationFactor,
"Detailed gossipsub scoring metric",
labels = ["agent"],
)
declarePublicCounter(
libp2p_gossipsub_peers_rate_limit_hits,
"The number of times peers were above their rate limit",
labels = ["agent"],
)
declareGauge(libp2p_gossipsub_peers_scores, "the scores of the peers in gossipsub", labels = ["agent"])
declareCounter(libp2p_gossipsub_bad_score_disconnection, "the number of peers disconnected by gossipsub", labels = ["agent"])
declareGauge(libp2p_gossipsub_peers_score_firstMessageDeliveries, "Detailed gossipsub scoring metric", labels = ["agent"])
declareGauge(libp2p_gossipsub_peers_score_meshMessageDeliveries, "Detailed gossipsub scoring metric", labels = ["agent"])
declareGauge(libp2p_gossipsub_peers_score_meshFailurePenalty, "Detailed gossipsub scoring metric", labels = ["agent"])
declareGauge(libp2p_gossipsub_peers_score_invalidMessageDeliveries, "Detailed gossipsub scoring metric", labels = ["agent"])
declareGauge(libp2p_gossipsub_peers_score_appScore, "Detailed gossipsub scoring metric", labels = ["agent"])
declareGauge(libp2p_gossipsub_peers_score_behaviourPenalty, "Detailed gossipsub scoring metric", labels = ["agent"])
declareGauge(libp2p_gossipsub_peers_score_colocationFactor, "Detailed gossipsub scoring metric", labels = ["agent"])
proc init*(_: type[TopicParams]): TopicParams =
TopicParams(
@@ -90,24 +49,21 @@ proc init*(_: type[TopicParams]): TopicParams =
meshFailurePenaltyWeight: -1.0,
meshFailurePenaltyDecay: 0.5,
invalidMessageDeliveriesWeight: -1.0,
invalidMessageDeliveriesDecay: 0.5,
invalidMessageDeliveriesDecay: 0.5
)
proc withPeerStats*(
g: GossipSub,
peerId: PeerId,
action: proc(stats: var PeerStats) {.gcsafe, raises: [].},
) =
action: proc (stats: var PeerStats) {.gcsafe, raises: [Defect].}) =
## Add or update peer statistics for a particular peer id - the statistics
## are retained across multiple connections until they expire
g.peerStats.withValue(peerId, stats):
g.peerStats.withValue(peerId, stats) do:
action(stats[])
do:
action(
g.peerStats.mgetOrPut(
peerId, PeerStats(expire: Moment.now() + g.parameters.retainScore)
)
)
action(g.peerStats.mgetOrPut(peerId, PeerStats(
expire: Moment.now() + g.parameters.retainScore
)))
func `/`(a, b: Duration): float64 =
let
@@ -115,34 +71,42 @@ func `/`(a, b: Duration): float64 =
fb = float64(b.nanoseconds)
fa / fb
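# Quick check of the Duration ratio helper above: dividing two chronos
# durations yields a unitless float.
import chronos

echo float64(3.seconds.nanoseconds) / float64(2.seconds.nanoseconds) # 1.5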
func byScore*(x, y: PubSubPeer): int =
system.cmp(x.score, y.score)
func byScore*(x,y: PubSubPeer): int = system.cmp(x.score, y.score)
proc colocationFactor(g: GossipSub, peer: PubSubPeer): float64 =
let address = peer.address.valueOr:
return 0.0
g.peersInIP.mgetOrPut(address, initHashSet[PeerId]()).incl(peer.peerId)
let ipPeers = g.peersInIP.getOrDefault(address).len().float64
if ipPeers > g.parameters.ipColocationFactorThreshold:
trace "colocationFactor over threshold", peer, address, ipPeers
let over = ipPeers - g.parameters.ipColocationFactorThreshold
over * over
else:
if peer.address.isNone():
0.0
else:
let
address = peer.address.get()
g.peersInIP.mgetOrPut(address, initHashSet[PeerId]()).incl(peer.peerId)
let
ipPeers = g.peersInIP.getOrDefault(address).len().float64
if ipPeers > g.parameters.ipColocationFactorThreshold:
trace "colocationFactor over threshold", peer, address, ipPeers
let over = ipPeers - g.parameters.ipColocationFactorThreshold
over * over
else:
0.0
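# Worked example of the quadratic colocation penalty above, with
# hypothetical numbers: threshold 3 and 5 peers behind one address
# give (5 - 3)^2 = 4, which the negative weight turns into score loss.
let
  threshold = 3.0 # ipColocationFactorThreshold, hypothetical
  ipPeers = 5.0 # peers sharing the address
let over = ipPeers - threshold
echo over * over # 4.0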
{.pop.}
proc disconnectPeer(g: GossipSub, peer: PubSubPeer) {.async.} =
let agent =
when defined(libp2p_agents_metrics):
if peer.shortAgent.len > 0:
peer.shortAgent
else:
"unknown"
else:
"unknown"
libp2p_gossipsub_bad_score_disconnection.inc(labelValues = [agent])
proc disconnectPeer*(g: GossipSub, peer: PubSubPeer) {.async.} =
try:
await g.switch.disconnect(peer.peerId)
except CatchableError as exc: # Never cancelled
trace "Failed to close connection", peer, error = exc.name, msg = exc.msg
proc disconnectIfBadScorePeer*(g: GossipSub, peer: PubSubPeer, score: float64) =
if g.parameters.disconnectBadPeers and score < g.parameters.graylistThreshold and
peer.peerId notin g.parameters.directPeers:
debug "disconnecting bad score peer", peer, score = peer.score
asyncSpawn(g.disconnectPeer(peer))
libp2p_gossipsub_bad_score_disconnection.inc(labelValues = [peer.getAgent()])
proc updateScores*(g: GossipSub) = # avoid async
## https://github.com/libp2p/specs/blob/master/pubsub/gossipsub/gossipsub-v1.1.md#the-score-function
@@ -154,7 +118,7 @@ proc updateScores*(g: GossipSub) = # avoid async
for peerId, stats in g.peerStats.mpairs:
let peer = g.peers.getOrDefault(peerId)
if isNil(peer) or not (peer.connected):
if isNil(peer) or not(peer.connected):
if now > stats.expire:
evicting.add(peerId)
trace "evicted peer from memory", peer = peerId
@@ -165,7 +129,7 @@ proc updateScores*(g: GossipSub) = # avoid async
var
n_topics = 0
is_grafted = 0
scoreAcc = 0.0 # accumulates the peer score
score = 0.0
# Per topic
for topic, topicParams in g.topicParams:
@@ -191,14 +155,12 @@ proc updateScores*(g: GossipSub) = # avoid async
else:
info.meshMessageDeliveriesActive = false
topicScore +=
info.firstMessageDeliveries * topicParams.firstMessageDeliveriesWeight
topicScore += info.firstMessageDeliveries * topicParams.firstMessageDeliveriesWeight
trace "p2", peer, p2 = info.firstMessageDeliveries, topic, topicScore
if info.meshMessageDeliveriesActive:
if info.meshMessageDeliveries < topicParams.meshMessageDeliveriesThreshold:
let deficit =
topicParams.meshMessageDeliveriesThreshold - info.meshMessageDeliveries
let deficit = topicParams.meshMessageDeliveriesThreshold - info.meshMessageDeliveries
let p3 = deficit * deficit
trace "p3", peer, p3, topic, topicScore
topicScore += p3 * topicParams.meshMessageDeliveriesWeight
@@ -206,34 +168,26 @@ proc updateScores*(g: GossipSub) = # avoid async
topicScore += info.meshFailurePenalty * topicParams.meshFailurePenaltyWeight
trace "p3b", peer, p3b = info.meshFailurePenalty, topic, topicScore
topicScore +=
info.invalidMessageDeliveries * info.invalidMessageDeliveries *
topicParams.invalidMessageDeliveriesWeight
trace "p4",
peer,
p4 = info.invalidMessageDeliveries * info.invalidMessageDeliveries,
topic,
topicScore
topicScore += info.invalidMessageDeliveries * info.invalidMessageDeliveries * topicParams.invalidMessageDeliveriesWeight
trace "p4", p4 = info.invalidMessageDeliveries * info.invalidMessageDeliveries, topic, topicScore
scoreAcc += topicScore * topicParams.topicWeight
trace "updated peer topic's scores", peer, topic, info, topicScore
trace "updated peer topic's scores",
peer, scoreAcc, topic, info, topicScore, topicWeight = topicParams.topicWeight
score += topicScore * topicParams.topicWeight
# Score metrics
let agent = peer.getAgent()
libp2p_gossipsub_peers_score_firstMessageDeliveries.inc(
info.firstMessageDeliveries, labelValues = [agent]
)
libp2p_gossipsub_peers_score_meshMessageDeliveries.inc(
info.meshMessageDeliveries, labelValues = [agent]
)
libp2p_gossipsub_peers_score_meshFailurePenalty.inc(
info.meshFailurePenalty, labelValues = [agent]
)
libp2p_gossipsub_peers_score_invalidMessageDeliveries.inc(
info.invalidMessageDeliveries, labelValues = [agent]
)
let agent =
when defined(libp2p_agents_metrics):
if peer.shortAgent.len > 0:
peer.shortAgent
else:
"unknown"
else:
"unknown"
libp2p_gossipsub_peers_score_firstMessageDeliveries.inc(info.firstMessageDeliveries, labelValues = [agent])
libp2p_gossipsub_peers_score_meshMessageDeliveries.inc(info.meshMessageDeliveries, labelValues = [agent])
libp2p_gossipsub_peers_score_meshFailurePenalty.inc(info.meshFailurePenalty, labelValues = [agent])
libp2p_gossipsub_peers_score_invalidMessageDeliveries.inc(info.invalidMessageDeliveries, labelValues = [agent])
# Score decay
info.firstMessageDeliveries *= topicParams.firstMessageDeliveriesDecay
@@ -256,45 +210,34 @@ proc updateScores*(g: GossipSub) = # avoid async
# commit our changes, mgetOrPut does NOT work as wanted with value types (lent?)
stats.topicInfos[topic] = info
scoreAcc += peer.appScore * g.parameters.appSpecificWeight
trace "appScore",
peer,
scoreAcc,
appScore = peer.appScore,
appSpecificWeight = g.parameters.appSpecificWeight
score += peer.appScore * g.parameters.appSpecificWeight
# The value of the parameter is the square of the counter and is mixed with a negative weight.
scoreAcc +=
peer.behaviourPenalty * peer.behaviourPenalty * g.parameters.behaviourPenaltyWeight
trace "behaviourPenalty",
peer,
scoreAcc,
behaviourPenalty = peer.behaviourPenalty,
behaviourPenaltyWeight = g.parameters.behaviourPenaltyWeight
score += peer.behaviourPenalty * peer.behaviourPenalty * g.parameters.behaviourPenaltyWeight
let colocationFactor = g.colocationFactor(peer)
scoreAcc += colocationFactor * g.parameters.ipColocationFactorWeight
trace "colocationFactor",
peer,
scoreAcc,
colocationFactor,
ipColocationFactorWeight = g.parameters.ipColocationFactorWeight
score += colocationFactor * g.parameters.ipColocationFactorWeight
# Score metrics
let agent = peer.getAgent()
let agent =
when defined(libp2p_agents_metrics):
if peer.shortAgent.len > 0:
peer.shortAgent
else:
"unknown"
else:
"unknown"
libp2p_gossipsub_peers_score_appScore.inc(peer.appScore, labelValues = [agent])
libp2p_gossipsub_peers_score_behaviourPenalty.inc(
peer.behaviourPenalty, labelValues = [agent]
)
libp2p_gossipsub_peers_score_colocationFactor.inc(
colocationFactor, labelValues = [agent]
)
libp2p_gossipsub_peers_score_behaviourPenalty.inc(peer.behaviourPenalty, labelValues = [agent])
libp2p_gossipsub_peers_score_colocationFactor.inc(colocationFactor, labelValues = [agent])
# decay behaviourPenalty
peer.behaviourPenalty *= g.parameters.behaviourPenaltyDecay
if peer.behaviourPenalty < g.parameters.decayToZero:
peer.behaviourPenalty = 0
peer.score = scoreAcc
peer.score = score
# copy into stats the score to keep until expired
stats.score = peer.score
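# One consequence of the quadratic mixing above, with hypothetical
# values: a behaviour-penalty counter of 3 and weight -1.0 costs
# 3 * 3 * -1.0 = -9.0, so repeat offenders lose score much faster
# than one-off ones.
let
  behaviourPenalty = 3.0
  behaviourPenaltyWeight = -1.0
echo behaviourPenalty * behaviourPenalty * behaviourPenaltyWeight # -9.0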
@@ -302,10 +245,13 @@ proc updateScores*(g: GossipSub) = # avoid async
stats.behaviourPenalty = peer.behaviourPenalty
stats.expire = now + g.parameters.retainScore # refresh expiration
trace "updated (accumulated) peer's score",
peer, peerScore = peer.score, n_topics, is_grafted
trace "updated peer's score", peer, score = peer.score, n_topics, is_grafted
if g.parameters.disconnectBadPeers and stats.score < g.parameters.graylistThreshold and
peer.peerId notin g.parameters.directPeers:
debug "disconnecting bad score peer", peer, score = peer.score
asyncSpawn(try: g.disconnectPeer(peer) except Exception as exc: raiseAssert exc.msg)
g.disconnectIfBadScorePeer(peer, stats.score)
libp2p_gossipsub_peers_scores.inc(peer.score, labelValues = [agent])
for peer in evicting:
@@ -318,54 +264,43 @@ proc scoringHeartbeat*(g: GossipSub) {.async.} =
trace "running scoring heartbeat", instance = cast[int](g)
g.updateScores()
proc punishInvalidMessage*(g: GossipSub, peer: PubSubPeer, msg: Message) {.async.} =
let uselessAppBytesNum = msg.data.len
peer.overheadRateLimitOpt.withValue(overheadRateLimit):
if not overheadRateLimit.tryConsume(uselessAppBytesNum):
debug "Peer sent invalid message and it's above rate limit",
peer, uselessAppBytesNum
libp2p_gossipsub_peers_rate_limit_hits.inc(labelValues = [peer.getAgent()])
# let's just measure at the beginning for test purposes.
if g.parameters.disconnectPeerAboveRateLimit:
await g.disconnectPeer(peer)
raise newException(
PeerRateLimitError, "Peer disconnected because it's above rate limit."
)
proc punishInvalidMessage*(g: GossipSub, peer: PubSubPeer, topics: seq[string]) =
for tt in topics:
let t = tt
if t notin g.topics:
continue
let topic = msg.topic
if topic notin g.topics:
return
# update stats
g.withPeerStats(peer.peerId) do(stats: var PeerStats):
stats.topicInfos.mgetOrPut(topic, TopicInfo()).invalidMessageDeliveries += 1
let tt = t
# update stats
g.withPeerStats(peer.peerId) do (stats: var PeerStats):
stats.topicInfos.mgetOrPut(tt, TopicInfo()).invalidMessageDeliveries += 1
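For context, the overhead rate limit consulted above is an Opt[TokenBucket] from chronos/ratelimit. A minimal sketch of the check in isolation, assuming the TokenBucket.new(budget, fillDuration) constructor and an arbitrary 4096-bytes-per-second budget:
import chronos
import chronos/ratelimit

let bucket = TokenBucket.new(4096, 1.seconds)  # budget in bytes, refill interval
let uselessAppBytesNum = 5000                  # hypothetical invalid-message size
if not bucket.tryConsume(uselessAppBytesNum):
  echo "above rate limit; the peer would be descored or disconnected"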
proc addCapped*[T](stat: var T, diff, cap: T) =
stat += min(diff, cap - stat)
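addCapped saturates at the cap instead of growing past it; a quick illustration with integers:
var firstDeliveries = 3
firstDeliveries.addCapped(5, 4)  # only min(5, 4 - 3) = 1 is added
assert firstDeliveries == 4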
proc rewardDelivered*(
g: GossipSub, peer: PubSubPeer, topic: string, first: bool, delay = ZeroDuration
) =
if topic notin g.topics:
return
g: GossipSub, peer: PubSubPeer, topics: openArray[string], first: bool, delay = ZeroDuration) =
for tt in topics:
let t = tt
if t notin g.topics:
continue
let topicParams = g.topicParams.mgetOrPut(topic, TopicParams.init())
# if in mesh add more delivery score
let tt = t
let topicParams = g.topicParams.mgetOrPut(t, TopicParams.init())
# if in mesh add more delivery score
if delay > topicParams.meshMessageDeliveriesWindow:
# Too old
return
if delay > topicParams.meshMessageDeliveriesWindow:
# Too old
continue
g.withPeerStats(peer.peerId) do(stats: var PeerStats):
stats.topicInfos.withValue(topic, tstats):
if first:
tstats[].firstMessageDeliveries.addCapped(
1, topicParams.firstMessageDeliveriesCap
)
g.withPeerStats(peer.peerId) do (stats: var PeerStats):
stats.topicInfos.withValue(tt, tstats):
if tstats[].inMesh:
if first:
tstats[].firstMessageDeliveries.addCapped(
1, topicParams.firstMessageDeliveriesCap)
if tstats[].inMesh:
tstats[].meshMessageDeliveries.addCapped(
1, topicParams.meshMessageDeliveriesCap
)
do:
stats.topicInfos[topic] = TopicInfo(meshMessageDeliveries: 1)
tstats[].meshMessageDeliveries.addCapped(
1, topicParams.meshMessageDeliveriesCap)
do: # make sure we don't lose this information
stats.topicInfos[tt] = TopicInfo(meshMessageDeliveries: 1)

View File

@@ -7,19 +7,19 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import chronos
import std/[options, tables, sets]
import std/[tables, sets]
import ".."/[floodsub, peertable, mcache, pubsubpeer]
import "../rpc"/[messages]
import "../../.."/[peerid, multiaddress, utility]
export options, tables, sets
const
GossipSubCodec_12* = "/meshsub/1.2.0"
GossipSubCodec_11* = "/meshsub/1.1.0"
GossipSubCodec* = "/meshsub/1.1.0"
GossipSubCodec_10* = "/meshsub/1.0.0"
# overlay parameters
@@ -33,18 +33,20 @@ const
GossipSubHistoryLength* = 5
GossipSubHistoryGossip* = 3
# heartbeat interval
# heartbeat interval
GossipSubHeartbeatInterval* = 1.seconds
# fanout ttl
const GossipSubFanoutTTL* = 1.minutes
const
GossipSubFanoutTTL* = 1.minutes
# gossip parameters
const GossipBackoffPeriod* = 1.minutes
const
GossipBackoffPeriod* = 1.minutes
const
BackoffSlackTime* = 2 # seconds
PingsPeerBudget* = 100 # maximum of 6.4kb/heartbeat (6.4kb/s with default 1 second/hb)
IWantPeerBudget* = 25 # 25 messages per second (reset every heartbeat)
IHavePeerBudget* = 10
# the max amount of IHave to expose; not mandated by the spec, following the go implementation as an example
# rust sigp: https://github.com/sigp/rust-libp2p/blob/f53d02bc873fef2bf52cd31e3d5ce366a41d8a8c/protocols/gossipsub/src/config.rs#L572
@@ -52,7 +54,8 @@ const
IHaveMaxLength* = 5000
type
TopicInfo* = object # gossip 1.1 related
TopicInfo* = object
# gossip 1.1 related
graftTime*: Moment
meshTime*: Duration
inMesh*: bool
@@ -100,11 +103,6 @@ type
behaviourPenalty*: float64 # the eventual penalty score
GossipSubParams* {.public.} = object
# explicit is used to check whether the GossipSubParams instance was created by the user, either by passing params to GossipSubParams(...)
# or by calling GossipSubParams.init(...). In the first case, explicit should be set to true when calling the Nim constructor.
# In the second case, the param isn't necessary and is always set to true by init.
# If neither option was used, the instance was created using Nim default values.
# In that case, GossipSubParams.init() should be called when initializing GossipSub so the values are set to the defaults defined by nim-libp2p.
explicit*: bool
pruneBackoff*: Duration
unsubscribeBackoff*: Duration
@@ -145,38 +143,29 @@ type
disconnectBadPeers*: bool
enablePX*: bool
bandwidthEstimatebps*: int
# This is currently used only for limiting flood publishing. 0 disables flood-limiting completely
overheadRateLimit*: Opt[tuple[bytes: int, interval: Duration]]
disconnectPeerAboveRateLimit*: bool
# Max number of elements allowed in the non-priority queue. When this limit has been reached, the peer will be disconnected.
maxNumElementsInNonPriorityQueue*: int
BackoffTable* = Table[string, Table[PeerId, Moment]]
ValidationSeenTable* = Table[SaltedId, HashSet[PubSubPeer]]
ValidationSeenTable* = Table[MessageId, HashSet[PubSubPeer]]
RoutingRecordsPair* = tuple[id: PeerId, record: Option[PeerRecord]]
RoutingRecordsHandler* = proc(
peer: PeerId,
RoutingRecordsHandler* =
proc(peer: PeerId,
tag: string, # For gossipsub, the topic
peers: seq[RoutingRecordsPair],
) {.gcsafe, raises: [].}
peers: seq[RoutingRecordsPair])
{.gcsafe, raises: [Defect].}
GossipSub* = ref object of FloodSub
mesh*: PeerTable # peers that we send messages to when we are subscribed to the topic
fanout*: PeerTable
# peers that we send messages to when we're not subscribed to the topic
gossipsub*: PeerTable # peers that are subscribed to a topic
subscribedDirectPeers*: PeerTable # directpeers that we keep alive
backingOff*: BackoffTable # peers to backoff from when replenishing the mesh
lastFanoutPubSub*: Table[string, Moment] # last publish time for fanout topics
mcache*: MCache # messages cache
validationSeen*: ValidationSeenTable # peers who sent us message in validation
heartbeatFut*: Future[void] # cancellation future for heartbeat interval
scoringHeartbeatFut*: Future[void]
# cancellation future for scoring heartbeat interval
mesh*: PeerTable # peers that we send messages to when we are subscribed to the topic
fanout*: PeerTable # peers that we send messages to when we're not subscribed to the topic
gossipsub*: PeerTable # peers that are subscribed to a topic
explicit*: PeerTable # directpeers that we keep alive explicitly
backingOff*: BackoffTable # peers to backoff from when replenishing the mesh
lastFanoutPubSub*: Table[string, Moment] # last publish time for fanout topics
gossip*: Table[string, seq[ControlIHave]] # pending gossip
control*: Table[string, ControlMessage] # pending control messages
mcache*: MCache # messages cache
validationSeen*: ValidationSeenTable # peers who sent us message in validation
heartbeatFut*: Future[void] # cancellation future for heartbeat interval
scoringHeartbeatFut*: Future[void] # cancellation future for scoring heartbeat interval
heartbeatRunning*: bool
peerStats*: Table[PeerId, PeerStats]
@@ -188,7 +177,8 @@ type
heartbeatEvents*: seq[AsyncEvent]
MeshMetrics* = object # scratch buffers for metrics
MeshMetrics* = object
# scratch buffers for metrics
otherPeersPerTopicMesh*: int64
otherPeersPerTopicFanout*: int64
otherPeersPerTopicGossipsub*: int64

View File

@@ -7,60 +7,60 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import std/[sets, tables]
import std/[sets, tables, options]
import rpc/[messages]
import results
export sets, tables, messages, results
export sets, tables, messages, options
type
CacheEntry* = object
msgId*: MessageId
topic*: string
mid*: MessageId
topicIds*: seq[string]
MCache* = object of RootObj
msgs*: Table[MessageId, Message]
history*: seq[seq[CacheEntry]]
pos*: int
windowSize*: Natural
func get*(c: MCache, msgId: MessageId): Opt[Message] =
if msgId in c.msgs:
try:
Opt.some(c.msgs[msgId])
except KeyError:
raiseAssert "checked"
func get*(c: MCache, mid: MessageId): Option[Message] =
if mid in c.msgs:
try: some(c.msgs[mid])
except KeyError: raiseAssert "checked"
else:
Opt.none(Message)
none(Message)
func contains*(c: MCache, msgId: MessageId): bool =
msgId in c.msgs
func contains*(c: MCache, mid: MessageId): bool =
mid in c.msgs
func put*(c: var MCache, msgId: MessageId, msg: Message) =
if not c.msgs.hasKeyOrPut(msgId, msg):
# Only add cache entry if the message was not already in the cache
c.history[c.pos].add(CacheEntry(msgId: msgId, topic: msg.topic))
c.history[0].add(CacheEntry(mid: msgId, topicIds: msg.topicIds))
func window*(c: MCache, topic: string): HashSet[MessageId] =
let len = min(c.windowSize, c.history.len)
let
len = min(c.windowSize, c.history.len)
for i in 0 ..< len:
# Work backwards from `pos` in the circular buffer
for entry in c.history[(c.pos + c.history.len - i) mod c.history.len]:
if entry.topic == topic:
result.incl(entry.msgId)
for i in 0..<len:
for entry in c.history[i]:
for t in entry.topicIds:
if t == topic:
result.incl(entry.mid)
break
func shift*(c: var MCache) =
# Shift circular buffer to write to a new position, clearing it from past
# iterations
c.pos = (c.pos + 1) mod c.history.len
for entry in c.history.pop():
c.msgs.del(entry.mid)
for entry in c.history[c.pos]:
c.msgs.del(entry.msgId)
reset(c.history[c.pos])
c.history.insert(@[])
func init*(T: type MCache, window, history: Natural): T =
T(history: newSeq[seq[CacheEntry]](history), windowSize: window)
T(
history: newSeq[seq[CacheEntry]](history),
windowSize: window
)
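As a usage sketch of the cache semantics above (assuming MessageId is a seq[byte] and the post-change Message with a single topic field):
var cache = MCache.init(window = 3, history = 5)
let id: MessageId = @[1'u8, 2, 3]
cache.put(id, Message(topic: "my-topic"))
assert id in cache                     # still stored in the msgs table
assert id in cache.window("my-topic")  # visible within the gossip window
for _ in 0 ..< 5:
  cache.shift()                        # one shift per heartbeat
assert id notin cache                  # dropped after `history` shifts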

View File

@@ -7,14 +7,16 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import std/[tables, sets, sequtils]
import std/[tables, sets]
import ./pubsubpeer, ../../peerid
export tables, sets
type PeerTable* = Table[string, HashSet[PubSubPeer]] # topic string to peer map
type
PeerTable* = Table[string, HashSet[PubSubPeer]] # topic string to peer map
proc hasPeerId*(t: PeerTable, topic: string, peerId: PeerId): bool =
if topic in t:
@@ -22,14 +24,15 @@ proc hasPeerId*(t: PeerTable, topic: string, peerId: PeerId): bool =
for peer in t[topic]:
if peer.peerId == peerId:
return true
except KeyError:
raiseAssert "checked with in"
except KeyError: raiseAssert "checked with in"
false
func addPeer*(table: var PeerTable, topic: string, peer: PubSubPeer): bool =
# returns true if the peer was added,
# false if it was already in the collection
not table.mgetOrPut(topic, initHashSet[PubSubPeer]()).containsOrIncl(peer)
not table.mgetOrPut(topic,
initHashSet[PubSubPeer]())
.containsOrIncl(peer)
func removePeer*(table: var PeerTable, topic: string, peer: PubSubPeer) =
table.withValue(topic, peers):
@@ -40,23 +43,11 @@ func removePeer*(table: var PeerTable, topic: string, peer: PubSubPeer) =
func hasPeer*(table: PeerTable, topic: string, peer: PubSubPeer): bool =
try:
(topic in table) and (peer in table[topic])
except KeyError:
raiseAssert "checked with in"
except KeyError: raiseAssert "checked with in"
func peers*(table: PeerTable, topic: string): int =
if topic in table:
try:
table[topic].len
except KeyError:
raiseAssert "checked with in"
else:
0
func outboundPeers*(table: PeerTable, topic: string): int =
if topic in table:
try:
table[topic].countIt(it.outbound)
except KeyError:
raiseAssert "checked with in"
try: table[topic].len
except KeyError: raiseAssert "checked with in"
else:
0
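A minimal usage sketch of these helpers, where peer is a hypothetical PubSubPeer instance:
var table: PeerTable
discard table.addPeer("my-topic", peer)  # returns true on first insert
assert table.hasPeer("my-topic", peer)
assert table.peers("my-topic") == 1
table.removePeer("my-topic", peer)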

View File

@@ -13,28 +13,29 @@
## `publish<#publish.e%2CPubSub%2Cstring%2Cseq%5Bbyte%5D>`_ something on it,
## and eventually `unsubscribe<#unsubscribe%2CPubSub%2Cstring%2CTopicHandler>`_ from it.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import std/[tables, sequtils, sets, strutils]
import chronos, chronicles, metrics
import chronos/ratelimit
import
./errors as pubsub_errors,
./pubsubpeer,
./rpc/[message, messages, protobuf],
../../switch,
../protocol,
../../crypto/crypto,
../../stream/connection,
../../peerid,
../../peerinfo,
../../errors,
../../utility
import ./errors as pubsub_errors,
./pubsubpeer,
./rpc/[message, messages, protobuf],
../../switch,
../protocol,
../../crypto/crypto,
../../stream/connection,
../../peerid,
../../peerinfo,
../../errors,
../../utility
import metrics
import stew/results
export results
export tables, sets
export PubSubPeer
export PubSubObserver
export protocol
@@ -51,119 +52,73 @@ declareGauge(libp2p_pubsub_peers, "pubsub peer instances")
declareGauge(libp2p_pubsub_topics, "pubsub subscribed topics")
declareCounter(libp2p_pubsub_subscriptions, "pubsub subscription operations")
declareCounter(libp2p_pubsub_unsubscriptions, "pubsub unsubscription operations")
declareGauge(
libp2p_pubsub_topic_handlers,
"pubsub subscribed topics handlers count",
labels = ["topic"],
)
declareGauge(libp2p_pubsub_topic_handlers, "pubsub subscribed topics handlers count", labels = ["topic"])
declareCounter(
libp2p_pubsub_validation_success, "pubsub successfully validated messages"
)
declareCounter(libp2p_pubsub_validation_success, "pubsub successfully validated messages")
declareCounter(libp2p_pubsub_validation_failure, "pubsub failed validated messages")
declareCounter(libp2p_pubsub_validation_ignore, "pubsub ignore validated messages")
declarePublicCounter(
libp2p_pubsub_messages_published, "published messages", labels = ["topic"]
)
declarePublicCounter(
libp2p_pubsub_messages_rebroadcasted, "re-broadcasted messages", labels = ["topic"]
)
declarePublicCounter(libp2p_pubsub_messages_published, "published messages", labels = ["topic"])
declarePublicCounter(libp2p_pubsub_messages_rebroadcasted, "re-broadcasted messages", labels = ["topic"])
declarePublicCounter(
libp2p_pubsub_broadcast_subscriptions,
"pubsub broadcast subscriptions",
labels = ["topic"],
)
declarePublicCounter(
libp2p_pubsub_broadcast_unsubscriptions,
"pubsub broadcast unsubscriptions",
labels = ["topic"],
)
declarePublicCounter(
libp2p_pubsub_broadcast_messages, "pubsub broadcast messages", labels = ["topic"]
)
declarePublicCounter(libp2p_pubsub_broadcast_subscriptions, "pubsub broadcast subscriptions", labels = ["topic"])
declarePublicCounter(libp2p_pubsub_broadcast_unsubscriptions, "pubsub broadcast unsubscriptions", labels = ["topic"])
declarePublicCounter(libp2p_pubsub_broadcast_messages, "pubsub broadcast messages", labels = ["topic"])
declarePublicCounter(
libp2p_pubsub_received_subscriptions,
"pubsub received subscriptions",
labels = ["topic"],
)
declarePublicCounter(
libp2p_pubsub_received_unsubscriptions,
"pubsub received subscriptions",
labels = ["topic"],
)
declarePublicCounter(
libp2p_pubsub_received_messages, "pubsub received messages", labels = ["topic"]
)
declarePublicCounter(libp2p_pubsub_received_subscriptions, "pubsub received subscriptions", labels = ["topic"])
declarePublicCounter(libp2p_pubsub_received_unsubscriptions, "pubsub received subscriptions", labels = ["topic"])
declarePublicCounter(libp2p_pubsub_received_messages, "pubsub received messages", labels = ["topic"])
declarePublicCounter(libp2p_pubsub_broadcast_iwant, "pubsub broadcast iwant")
declarePublicCounter(
libp2p_pubsub_broadcast_ihave, "pubsub broadcast ihave", labels = ["topic"]
)
declarePublicCounter(
libp2p_pubsub_broadcast_graft, "pubsub broadcast graft", labels = ["topic"]
)
declarePublicCounter(
libp2p_pubsub_broadcast_prune, "pubsub broadcast prune", labels = ["topic"]
)
declarePublicCounter(libp2p_pubsub_broadcast_ihave, "pubsub broadcast ihave", labels = ["topic"])
declarePublicCounter(libp2p_pubsub_broadcast_graft, "pubsub broadcast graft", labels = ["topic"])
declarePublicCounter(libp2p_pubsub_broadcast_prune, "pubsub broadcast prune", labels = ["topic"])
declarePublicCounter(libp2p_pubsub_received_iwant, "pubsub broadcast iwant")
declarePublicCounter(
libp2p_pubsub_received_ihave, "pubsub broadcast ihave", labels = ["topic"]
)
declarePublicCounter(
libp2p_pubsub_received_graft, "pubsub broadcast graft", labels = ["topic"]
)
declarePublicCounter(
libp2p_pubsub_received_prune, "pubsub broadcast prune", labels = ["topic"]
)
declarePublicCounter(libp2p_pubsub_received_ihave, "pubsub broadcast ihave", labels = ["topic"])
declarePublicCounter(libp2p_pubsub_received_graft, "pubsub broadcast graft", labels = ["topic"])
declarePublicCounter(libp2p_pubsub_received_prune, "pubsub broadcast prune", labels = ["topic"])
type
InitializationError* = object of LPError
TopicHandler* {.public.} =
proc(topic: string, data: seq[byte]): Future[void] {.gcsafe, raises: [].}
TopicHandler* {.public.} = proc(topic: string,
data: seq[byte]): Future[void] {.gcsafe, raises: [Defect].}
ValidatorHandler* {.public.} = proc(
topic: string, message: Message
): Future[ValidationResult] {.gcsafe, raises: [].}
ValidatorHandler* {.public.} = proc(topic: string,
message: Message): Future[ValidationResult] {.gcsafe, raises: [Defect].}
TopicPair* = tuple[topic: string, handler: TopicHandler]
MsgIdProvider* {.public.} = proc(m: Message): Result[MessageId, ValidationResult] {.
noSideEffect, raises: [], gcsafe
.}
MsgIdProvider* {.public.} =
proc(m: Message): Result[MessageId, ValidationResult] {.noSideEffect, raises: [Defect], gcsafe.}
SubscriptionValidator* {.public.} = proc(topic: string): bool {.raises: [], gcsafe.}
SubscriptionValidator* {.public.} =
proc(topic: string): bool {.raises: [Defect], gcsafe.}
## Every time a peer sends us a subscription (even to an unknown topic),
## we have to store it, which may be an attack vector.
## This callback can be used to reject topics we're not interested in
PubSub* {.public.} = ref object of LPProtocol
switch*: Switch # the switch used to dial/connect to peers
peerInfo*: PeerInfo # this peer's info
topics*: Table[string, seq[TopicHandler]] # the topics that _we_ are interested in
peers*: Table[PeerId, PubSubPeer]
#\
switch*: Switch # the switch used to dial/connect to peers
peerInfo*: PeerInfo # this peer's info
topics*: Table[string, seq[TopicHandler]] # the topics that _we_ are interested in
peers*: Table[PeerId, PubSubPeer] #\
# Peers that we are interested to gossip with (but not necessarily
# yet connected to)
triggerSelf*: bool ## trigger own local handler on publish
verifySignature*: bool ## enable signature verification
sign*: bool ## enable message signing
triggerSelf*: bool ## trigger own local handler on publish
verifySignature*: bool ## enable signature verification
sign*: bool ## enable message signing
validators*: Table[string, HashSet[ValidatorHandler]]
observers: ref seq[PubSubObserver] # ref as in smart_ptr
msgIdProvider*: MsgIdProvider ## Turn message into message id (not nil)
msgIdProvider*: MsgIdProvider ## Turn message into message id (not nil)
msgSeqno*: uint64
anonymize*: bool ## if we omit fromPeer and seqno from RPC messages we send
subscriptionValidator*: SubscriptionValidator
# callback used to validate subscriptions
topicsHigh*: int ## the maximum number of topics a peer is allowed to subscribe to
maxMessageSize*: int
##\
anonymize*: bool ## if we omit fromPeer and seqno from RPC messages we send
subscriptionValidator*: SubscriptionValidator # callback used to validate subscriptions
topicsHigh*: int ## the maximum number of topics a peer is allowed to subscribe to
maxMessageSize*: int ##\
## the maximum raw message size we'll globally allow
## for finer tuning, check message size on topic validator
##
@@ -184,37 +139,18 @@ method unsubscribePeer*(p: PubSub, peerId: PeerId) {.base, gcsafe.} =
libp2p_pubsub_peers.set(p.peers.len.int64)
proc send*(
p: PubSub, peer: PubSubPeer, msg: RPCMsg, isHighPriority: bool
) {.raises: [].} =
## This procedure attempts to send a `msg` (of type `RPCMsg`) to the specified remote peer in the PubSub network.
proc send*(p: PubSub, peer: PubSubPeer, msg: RPCMsg) {.raises: [Defect].} =
## Attempt to send `msg` to remote peer
##
## Parameters:
## - `p`: The `PubSub` instance.
## - `peer`: An instance of `PubSubPeer` representing the peer to whom the message should be sent.
## - `msg`: The `RPCMsg` instance that contains the message to be sent.
## - `isHighPriority`: A boolean indicating whether the message should be treated as high priority.
## High priority messages are sent immediately, while low priority messages are queued and sent only after all high
## priority messages have been sent.
trace "sending pubsub message to peer", peer, msg = shortLog(msg)
peer.send(msg, p.anonymize, isHighPriority)
peer.send(msg, p.anonymize)
proc broadcast*(
p: PubSub,
sendPeers: auto, # Iterable[PubSubPeer]
msg: RPCMsg,
isHighPriority: bool,
) {.raises: [].} =
## This procedure attempts to send a `msg` (of type `RPCMsg`) to a specified group of peers in the PubSub network.
##
## Parameters:
## - `p`: The `PubSub` instance.
## - `sendPeers`: An iterable of `PubSubPeer` instances representing the peers to whom the message should be sent.
## - `msg`: The `RPCMsg` instance that contains the message to be broadcast.
## - `isHighPriority`: A boolean indicating whether the message should be treated as high priority.
## High priority messages are sent immediately, while low priority messages are queued and sent only after all high
## priority messages have been sent.
p: PubSub,
sendPeers: auto, # Iterable[PubSubPeer]
msg: RPCMsg) {.raises: [Defect].} =
## Attempt to send `msg` to the given peers
let npeers = sendPeers.len.int64
for sub in msg.subscriptions:
@@ -230,47 +166,50 @@ proc broadcast*(
libp2p_pubsub_broadcast_unsubscriptions.inc(npeers, labelValues = ["generic"])
for smsg in msg.messages:
let topic = smsg.topic
if p.knownTopics.contains(topic):
libp2p_pubsub_broadcast_messages.inc(npeers, labelValues = [topic])
else:
libp2p_pubsub_broadcast_messages.inc(npeers, labelValues = ["generic"])
for topic in smsg.topicIds:
if p.knownTopics.contains(topic):
libp2p_pubsub_broadcast_messages.inc(npeers, labelValues = [topic])
else:
libp2p_pubsub_broadcast_messages.inc(npeers, labelValues = ["generic"])
msg.control.withValue(control):
libp2p_pubsub_broadcast_iwant.inc(npeers * control.iwant.len.int64)
if msg.control.isSome():
libp2p_pubsub_broadcast_iwant.inc(npeers * msg.control.get().iwant.len.int64)
let control = msg.control.get()
for ihave in control.ihave:
if p.knownTopics.contains(ihave.topicID):
libp2p_pubsub_broadcast_ihave.inc(npeers, labelValues = [ihave.topicID])
if p.knownTopics.contains(ihave.topicId):
libp2p_pubsub_broadcast_ihave.inc(npeers, labelValues = [ihave.topicId])
else:
libp2p_pubsub_broadcast_ihave.inc(npeers, labelValues = ["generic"])
for graft in control.graft:
if p.knownTopics.contains(graft.topicID):
libp2p_pubsub_broadcast_graft.inc(npeers, labelValues = [graft.topicID])
if p.knownTopics.contains(graft.topicId):
libp2p_pubsub_broadcast_graft.inc(npeers, labelValues = [graft.topicId])
else:
libp2p_pubsub_broadcast_graft.inc(npeers, labelValues = ["generic"])
for prune in control.prune:
if p.knownTopics.contains(prune.topicID):
libp2p_pubsub_broadcast_prune.inc(npeers, labelValues = [prune.topicID])
if p.knownTopics.contains(prune.topicId):
libp2p_pubsub_broadcast_prune.inc(npeers, labelValues = [prune.topicId])
else:
libp2p_pubsub_broadcast_prune.inc(npeers, labelValues = ["generic"])
trace "broadcasting messages to peers", peers = sendPeers.len, msg = shortLog(msg)
trace "broadcasting messages to peers",
peers = sendPeers.len, msg = shortLog(msg)
if anyIt(sendPeers, it.hasObservers):
for peer in sendPeers:
p.send(peer, msg, isHighPriority)
p.send(peer, msg)
else:
# Fast path that only encodes message once
let encoded = encodeRpcMsg(msg, p.anonymize)
for peer in sendPeers:
asyncSpawn peer.sendEncoded(encoded, isHighPriority)
asyncSpawn peer.sendEncoded(encoded)
proc sendSubs*(
p: PubSub, peer: PubSubPeer, topics: openArray[string], subscribe: bool
) =
proc sendSubs*(p: PubSub,
peer: PubSubPeer,
topics: openArray[string],
subscribe: bool) =
## send subscriptions to remote peer
p.send(peer, RPCMsg.withSubs(topics, subscribe), isHighPriority = true)
p.send(peer, RPCMsg.withSubs(topics, subscribe))
for topic in topics:
if subscribe:
@@ -285,10 +224,8 @@ proc sendSubs*(
libp2p_pubsub_broadcast_unsubscriptions.inc(labelValues = ["generic"])
proc updateMetrics*(p: PubSub, rpcMsg: RPCMsg) =
for i in 0 ..< min(rpcMsg.subscriptions.len, p.topicsHigh):
template sub(): untyped =
rpcMsg.subscriptions[i]
for i in 0..<min(rpcMsg.subscriptions.len, p.topicsHigh):
template sub(): untyped = rpcMsg.subscriptions[i]
if sub.subscribe:
if p.knownTopics.contains(sub.topic):
libp2p_pubsub_received_subscriptions.inc(labelValues = [sub.topic])
@@ -300,70 +237,66 @@ proc updateMetrics*(p: PubSub, rpcMsg: RPCMsg) =
else:
libp2p_pubsub_received_unsubscriptions.inc(labelValues = ["generic"])
for i in 0 ..< rpcMsg.messages.len():
let topic = rpcMsg.messages[i].topic
if p.knownTopics.contains(topic):
libp2p_pubsub_received_messages.inc(labelValues = [topic])
else:
libp2p_pubsub_received_messages.inc(labelValues = ["generic"])
for i in 0..<rpcMsg.messages.len():
template smsg: untyped = rpcMsg.messages[i]
for j in 0..<smsg.topicIds.len():
template topic: untyped = smsg.topicIds[j]
if p.knownTopics.contains(topic):
libp2p_pubsub_received_messages.inc(labelValues = [topic])
else:
libp2p_pubsub_received_messages.inc(labelValues = ["generic"])
rpcMsg.control.withValue(control):
libp2p_pubsub_received_iwant.inc(control.iwant.len.int64)
if rpcMsg.control.isSome():
libp2p_pubsub_received_iwant.inc(rpcMsg.control.get().iwant.len.int64)
template control: untyped = rpcMsg.control.unsafeGet()
for ihave in control.ihave:
if p.knownTopics.contains(ihave.topicID):
libp2p_pubsub_received_ihave.inc(labelValues = [ihave.topicID])
if p.knownTopics.contains(ihave.topicId):
libp2p_pubsub_received_ihave.inc(labelValues = [ihave.topicId])
else:
libp2p_pubsub_received_ihave.inc(labelValues = ["generic"])
for graft in control.graft:
if p.knownTopics.contains(graft.topicID):
libp2p_pubsub_received_graft.inc(labelValues = [graft.topicID])
if p.knownTopics.contains(graft.topicId):
libp2p_pubsub_received_graft.inc(labelValues = [graft.topicId])
else:
libp2p_pubsub_received_graft.inc(labelValues = ["generic"])
for prune in control.prune:
if p.knownTopics.contains(prune.topicID):
libp2p_pubsub_received_prune.inc(labelValues = [prune.topicID])
if p.knownTopics.contains(prune.topicId):
libp2p_pubsub_received_prune.inc(labelValues = [prune.topicId])
else:
libp2p_pubsub_received_prune.inc(labelValues = ["generic"])
method rpcHandler*(
p: PubSub, peer: PubSubPeer, data: seq[byte]
): Future[void] {.base, async.} =
method rpcHandler*(p: PubSub,
peer: PubSubPeer,
rpcMsg: RPCMsg): Future[void] {.base, async.} =
## Handler that must be overridden by concrete implementation
raiseAssert "Unimplemented"
method onNewPeer(p: PubSub, peer: PubSubPeer) {.base, gcsafe.} =
discard
method onNewPeer(p: PubSub, peer: PubSubPeer) {.base, gcsafe.} = discard
method onPubSubPeerEvent*(
p: PubSub, peer: PubSubPeer, event: PubSubPeerEvent
) {.base, gcsafe.} =
method onPubSubPeerEvent*(p: PubSub, peer: PubSubPeer, event: PubSubPeerEvent) {.base, gcsafe.} =
# Peer event is raised for the send connection in particular
case event.kind
of PubSubPeerEventKind.StreamOpened:
of PubSubPeerEventKind.Connected:
if p.topics.len > 0:
p.sendSubs(peer, toSeq(p.topics.keys), true)
of PubSubPeerEventKind.StreamClosed:
discard
of PubSubPeerEventKind.DisconnectionRequested:
of PubSubPeerEventKind.Disconnected:
discard
method getOrCreatePeer*(
p: PubSub, peerId: PeerId, protosToDial: seq[string], protoNegotiated: string = ""
): PubSubPeer {.base, gcsafe.} =
proc getOrCreatePeer*(
p: PubSub,
peerId: PeerId,
protos: seq[string]): PubSubPeer =
p.peers.withValue(peerId, peer):
if peer[].codec == "":
peer[].codec = protoNegotiated
return peer[]
proc getConn(): Future[Connection] {.async.} =
return await p.switch.dial(peerId, protosToDial)
return await p.switch.dial(peerId, protos)
proc onEvent(peer: PubSubPeer, event: PubSubPeerEvent) {.gcsafe.} =
p.onPubSubPeerEvent(peer, event)
# create new pubsub peer
let pubSubPeer =
PubSubPeer.new(peerId, getConn, onEvent, protoNegotiated, p.maxMessageSize)
let pubSubPeer = PubSubPeer.new(peerId, getConn, onEvent, protos[0], p.maxMessageSize)
debug "created new pubsub peer", peerId
p.peers[peerId] = pubSubPeer
@@ -379,7 +312,7 @@ method getOrCreatePeer*(
proc handleData*(p: PubSub, topic: string, data: seq[byte]): Future[void] =
# Start work on all data handlers without copying data into closure like
# happens on {.async.} transformation
p.topics.withValue(topic, handlers):
p.topics.withValue(topic, handlers) do:
var futs = newSeq[Future[void]]()
for handler in handlers[]:
@@ -396,7 +329,7 @@ proc handleData*(p: PubSub, topic: string, data: seq[byte]): Future[void] =
except CancelledError:
# propagate cancellation
for fut in futs:
if not (fut.finished):
if not(fut.finished):
fut.cancel()
# check for errors in futures
@@ -404,7 +337,6 @@ proc handleData*(p: PubSub, topic: string, data: seq[byte]): Future[void] =
if fut.failed:
let err = fut.readError()
warn "Error in topic handler", msg = err.msg
return waiter()
# Fast path - futures finished synchronously or nobody cared about data
@@ -412,7 +344,9 @@ proc handleData*(p: PubSub, topic: string, data: seq[byte]): Future[void] =
res.complete()
return res
method handleConn*(p: PubSub, conn: Connection, proto: string) {.base, async.} =
method handleConn*(p: PubSub,
conn: Connection,
proto: string) {.base, async.} =
## handle incoming connections
##
## this proc will:
@@ -424,11 +358,11 @@ method handleConn*(p: PubSub, conn: Connection, proto: string) {.base, async.} =
## that we're interested in
##
proc handler(peer: PubSubPeer, data: seq[byte]): Future[void] =
proc handler(peer: PubSubPeer, msg: RPCMsg): Future[void] =
# call pubsub rpc handler
p.rpcHandler(peer, data)
p.rpcHandler(peer, msg)
let peer = p.getOrCreatePeer(conn.peerId, @[], proto)
let peer = p.getOrCreatePeer(conn.peerId, @[proto])
try:
peer.handler = handler
@@ -454,38 +388,33 @@ proc updateTopicMetrics(p: PubSub, topic: string) =
libp2p_pubsub_topics.set(p.topics.len.int64)
if p.knownTopics.contains(topic):
p.topics.withValue(topic, handlers):
p.topics.withValue(topic, handlers) do:
libp2p_pubsub_topic_handlers.set(handlers[].len.int64, labelValues = [topic])
do:
libp2p_pubsub_topic_handlers.set(0, labelValues = [topic])
else:
var others: int64 = 0
for key, val in p.topics:
if key notin p.knownTopics:
others += 1
if key notin p.knownTopics: others += 1
libp2p_pubsub_topic_handlers.set(others, labelValues = ["other"])
method onTopicSubscription*(
p: PubSub, topic: string, subscribed: bool
) {.base, gcsafe.} =
method onTopicSubscription*(p: PubSub, topic: string, subscribed: bool) {.base, gcsafe.} =
# Called when subscribe is called the first time for a topic or unsubscribe
# removes the last handler
# Notify others that we are no longer interested in the topic
for _, peer in p.peers:
# If we don't have a sendConn yet, we will
# send the full sub list when we get the sendConn,
# so no need to send it here
if peer.hasSendConn:
p.sendSubs(peer, [topic], subscribed)
p.sendSubs(peer, [topic], subscribed)
if subscribed:
libp2p_pubsub_subscriptions.inc()
else:
libp2p_pubsub_unsubscriptions.inc()
proc unsubscribe*(p: PubSub, topic: string, handler: TopicHandler) {.public.} =
proc unsubscribe*(p: PubSub,
topic: string,
handler: TopicHandler) {.public.} =
## unsubscribe from a ``topic`` string
##
p.topics.withValue(topic, handlers):
@@ -514,7 +443,9 @@ proc unsubscribeAll*(p: PubSub, topic: string) {.public, gcsafe.} =
p.updateTopicMetrics(topic)
proc subscribe*(p: PubSub, topic: string, handler: TopicHandler) {.public.} =
proc subscribe*(p: PubSub,
topic: string,
handler: TopicHandler) {.public.} =
## subscribe to a topic
##
## ``topic`` - a string topic to subscribe to
@@ -528,7 +459,7 @@ proc subscribe*(p: PubSub, topic: string, handler: TopicHandler) {.public.} =
warn "Trying to subscribe to a topic not passing validation!", topic
return
p.topics.withValue(topic, handlers):
p.topics.withValue(topic, handlers) do:
# Already subscribed, just adding another handler
handlers[].add(handler)
do:
@@ -540,9 +471,9 @@ proc subscribe*(p: PubSub, topic: string, handler: TopicHandler) {.public.} =
p.updateTopicMetrics(topic)
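Putting subscribe, publish and unsubscribe together, a hedged end-to-end sketch (pubsub is an already-initialized PubSub instance, and the snippet runs inside an async proc):
proc onMessage(topic: string, data: seq[byte]) {.async.} =
  echo "received ", data.len, " bytes on ", topic

pubsub.subscribe("my-topic", onMessage)
let attempted = await pubsub.publish("my-topic", @[1'u8, 2, 3])
echo "attempted to send to ", attempted, " peers"
pubsub.unsubscribe("my-topic", onMessage)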
method publish*(
p: PubSub, topic: string, data: seq[byte]
): Future[int] {.base, async, public.} =
method publish*(p: PubSub,
topic: string,
data: seq[byte]): Future[int] {.base, async, public.} =
## publish to a ``topic``
##
## The return value is the number of neighbours that we attempted to send the
@@ -554,43 +485,42 @@ method publish*(
return 0
method initPubSub*(p: PubSub) {.base, raises: [InitializationError].} =
method initPubSub*(p: PubSub)
{.base, raises: [Defect, InitializationError].} =
## perform pubsub initialization
p.observers = new(seq[PubSubObserver])
if p.msgIdProvider == nil:
p.msgIdProvider = defaultMsgIdProvider
method addValidator*(
p: PubSub, topic: varargs[string], hook: ValidatorHandler
) {.base, public, gcsafe.} =
method addValidator*(p: PubSub,
topic: varargs[string],
hook: ValidatorHandler) {.base, public, gcsafe.} =
## Add a validator to a `topic`. Each new message received on this
## topic will be sent to `hook`. `hook` can return either `Accept`,
## `Ignore` or `Reject` (which can descore the peer)
for t in topic:
trace "adding validator for topic", topic = t
trace "adding validator for topic", topicId = t
p.validators.mgetOrPut(t, HashSet[ValidatorHandler]()).incl(hook)
method removeValidator*(
p: PubSub, topic: varargs[string], hook: ValidatorHandler
) {.base, public.} =
method removeValidator*(p: PubSub,
topic: varargs[string],
hook: ValidatorHandler) {.base, public.} =
for t in topic:
p.validators.withValue(t, validators):
validators[].excl(hook)
if validators[].len() == 0:
p.validators.del(t)
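A hedged sketch of a validator hook as described above; the empty-payload rule is illustrative only:
proc myValidator(topic: string, message: Message): Future[ValidationResult] {.async.} =
  # Reject empty payloads, accept everything else
  if message.data.len > 0:
    return ValidationResult.Accept
  return ValidationResult.Reject

pubsub.addValidator("my-topic", myValidator)
# later, when the hook is no longer needed:
pubsub.removeValidator("my-topic", myValidator)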
method validate*(
p: PubSub, message: Message
): Future[ValidationResult] {.async, base.} =
method validate*(p: PubSub, message: Message): Future[ValidationResult] {.async, base.} =
var pending: seq[Future[ValidationResult]]
trace "about to validate message"
let topic = message.topic
trace "looking for validators on topic",
topic = topic, registered = toSeq(p.validators.keys)
if topic in p.validators:
trace "running validators for topic", topic = topic
for validator in p.validators[topic]:
pending.add(validator(topic, message))
for topic in message.topicIds:
trace "looking for validators on topic", topicId = topic,
registered = toSeq(p.validators.keys)
if topic in p.validators:
trace "running validators for topic", topicId = topic
for validator in p.validators[topic]:
pending.add(validator(topic, message))
result = ValidationResult.Accept
let futs = await allFinished(pending)
@@ -613,22 +543,21 @@ method validate*(
libp2p_pubsub_validation_ignore.inc()
proc init*[PubParams: object | bool](
P: typedesc[PubSub],
switch: Switch,
triggerSelf: bool = false,
anonymize: bool = false,
verifySignature: bool = true,
sign: bool = true,
msgIdProvider: MsgIdProvider = defaultMsgIdProvider,
subscriptionValidator: SubscriptionValidator = nil,
maxMessageSize: int = 1024 * 1024,
rng: ref HmacDrbgContext = newRng(),
parameters: PubParams = false,
): P {.raises: [InitializationError], public.} =
P: typedesc[PubSub],
switch: Switch,
triggerSelf: bool = false,
anonymize: bool = false,
verifySignature: bool = true,
sign: bool = true,
msgIdProvider: MsgIdProvider = defaultMsgIdProvider,
subscriptionValidator: SubscriptionValidator = nil,
maxMessageSize: int = 1024 * 1024,
rng: ref HmacDrbgContext = newRng(),
parameters: PubParams = false): P
{.raises: [Defect, InitializationError], public.} =
let pubsub =
when PubParams is bool:
P(
switch: switch,
P(switch: switch,
peerInfo: switch.peerInfo,
triggerSelf: triggerSelf,
anonymize: anonymize,
@@ -638,11 +567,9 @@ proc init*[PubParams: object | bool](
subscriptionValidator: subscriptionValidator,
maxMessageSize: maxMessageSize,
rng: rng,
topicsHigh: int.high,
)
topicsHigh: int.high)
else:
P(
switch: switch,
P(switch: switch,
peerInfo: switch.peerInfo,
triggerSelf: triggerSelf,
anonymize: anonymize,
@@ -653,8 +580,7 @@ proc init*[PubParams: object | bool](
parameters: parameters,
maxMessageSize: maxMessageSize,
rng: rng,
topicsHigh: int.high,
)
topicsHigh: int.high)
proc peerEventHandler(peerId: PeerId, event: PeerEvent) {.async.} =
if event.kind == PeerEventKind.Joined:
@@ -671,10 +597,9 @@ proc init*[PubParams: object | bool](
return pubsub
proc addObserver*(p: PubSub, observer: PubSubObserver) {.public.} =
p.observers[] &= observer
proc addObserver*(p: PubSub; observer: PubSubObserver) {.public.} = p.observers[] &= observer
proc removeObserver*(p: PubSub, observer: PubSubObserver) {.public.} =
proc removeObserver*(p: PubSub; observer: PubSubObserver) {.public.} =
let idx = p.observers[].find(observer)
if idx != -1:
p.observers[].del(idx)

View File

@@ -7,94 +7,54 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
when (NimMajor, NimMinor) < (1, 4):
{.push raises: [Defect].}
else:
{.push raises: [].}
import std/[sequtils, strutils, tables, hashes, options, sets, deques]
import std/[sequtils, strutils, tables, hashes, options]
import stew/results
import chronos, chronicles, nimcrypto/sha2, metrics
import chronos/ratelimit
import
rpc/[messages, message, protobuf],
../../peerid,
../../peerinfo,
../../stream/connection,
../../crypto/crypto,
../../protobuf/minprotobuf,
../../utility
import rpc/[messages, message, protobuf],
../../peerid,
../../peerinfo,
../../stream/connection,
../../crypto/crypto,
../../protobuf/minprotobuf,
../../utility
export peerid, connection, deques
export peerid, connection
logScope:
topics = "libp2p pubsubpeer"
when defined(libp2p_expensive_metrics):
declareCounter(
libp2p_pubsub_sent_messages, "number of messages sent", labels = ["id", "topic"]
)
declareCounter(
libp2p_pubsub_skipped_received_messages,
"number of received skipped messages",
labels = ["id"],
)
declareCounter(
libp2p_pubsub_skipped_sent_messages,
"number of sent skipped messages",
labels = ["id"],
)
when defined(pubsubpeer_queue_metrics):
declareGauge(
libp2p_gossipsub_priority_queue_size,
"the number of messages in the priority queue",
labels = ["id"],
)
declareGauge(
libp2p_gossipsub_non_priority_queue_size,
"the number of messages in the non-priority queue",
labels = ["id"],
)
declareCounter(
libp2p_pubsub_disconnects_over_non_priority_queue_limit,
"number of peers disconnected due to over non-prio queue capacity",
)
const DefaultMaxNumElementsInNonPriorityQueue* = 1024
declareCounter(libp2p_pubsub_sent_messages, "number of messages sent", labels = ["id", "topic"])
declareCounter(libp2p_pubsub_received_messages, "number of messages received", labels = ["id", "topic"])
declareCounter(libp2p_pubsub_skipped_received_messages, "number of received skipped messages", labels = ["id"])
declareCounter(libp2p_pubsub_skipped_sent_messages, "number of sent skipped messages", labels = ["id"])
type
PeerRateLimitError* = object of CatchableError
PubSubObserver* = ref object
onRecv*: proc(peer: PubSubPeer, msgs: var RPCMsg) {.gcsafe, raises: [].}
onSend*: proc(peer: PubSubPeer, msgs: var RPCMsg) {.gcsafe, raises: [].}
onRecv*: proc(peer: PubSubPeer; msgs: var RPCMsg) {.gcsafe, raises: [Defect].}
onSend*: proc(peer: PubSubPeer; msgs: var RPCMsg) {.gcsafe, raises: [Defect].}
PubSubPeerEventKind* {.pure.} = enum
StreamOpened
StreamClosed
DisconnectionRequested
# tells gossipsub that the transport connection to the peer should be closed
Connected
Disconnected
PubSubPeerEvent* = object
kind*: PubSubPeerEventKind
GetConn* = proc(): Future[Connection] {.gcsafe, raises: [].}
DropConn* = proc(peer: PubSubPeer) {.gcsafe, raises: [].}
# have to pass peer as it's unknown during init
OnEvent* = proc(peer: PubSubPeer, event: PubSubPeerEvent) {.gcsafe, raises: [].}
RpcMessageQueue* = ref object
# Tracks async tasks for sending high-priority peer-published messages.
sendPriorityQueue: Deque[Future[void]]
# Queue for lower-priority messages, like "IWANT" replies and relay messages.
nonPriorityQueue: AsyncQueue[seq[byte]]
# Task for processing non-priority message queue.
sendNonPriorityTask: Future[void]
GetConn* = proc(): Future[Connection] {.gcsafe, raises: [Defect].}
DropConn* = proc(peer: PubSubPeer) {.gcsafe, raises: [Defect].} # have to pass peer as it's unknown during init
OnEvent* = proc(peer: PubSubPeer, event: PubSubPeerEvent) {.gcsafe, raises: [Defect].}
PubSubPeer* = ref object of RootObj
getConn*: GetConn # callback to establish a new send connection
onEvent*: OnEvent # Connectivity updates for peer
codec*: string # the protocol that this peer joined from
sendConn*: Connection # cached send connection
getConn*: GetConn # callback to establish a new send connection
onEvent*: OnEvent # Connectivity updates for peer
codec*: string # the protocol that this peer joined from
sendConn*: Connection # cached send connection
connectedFut: Future[void]
address*: Option[MultiAddress]
peerId*: PeerId
@@ -102,41 +62,21 @@ type
observers*: ref seq[PubSubObserver] # ref as in smart_ptr
score*: float64
sentIHaves*: Deque[HashSet[MessageId]]
iDontWants*: Deque[HashSet[SaltedId]]
## IDONTWANT contains unvalidated message ids which may be long and/or
## expensive to look up, so we apply the same salting to them as during
## unvalidated message processing
iWantBudget*: int
iHaveBudget*: int
pingBudget*: int
maxMessageSize: int
appScore*: float64 # application specific score
behaviourPenalty*: float64 # the eventual penalty score
overheadRateLimitOpt*: Opt[TokenBucket]
rpcmessagequeue: RpcMessageQueue
maxNumElementsInNonPriorityQueue*: int
# The max number of elements allowed in the non-priority queue.
disconnected: bool
RPCHandler* =
proc(peer: PubSubPeer, data: seq[byte]): Future[void] {.gcsafe, raises: [].}
when defined(libp2p_agents_metrics):
func shortAgent*(p: PubSubPeer): string =
if p.sendConn.isNil or p.sendConn.getWrapped().isNil:
"unknown"
else:
#TODO the sendConn is set up before identify,
#so we have to read the parent's short agent..
p.sendConn.getWrapped().shortAgent
proc getAgent*(peer: PubSubPeer): string =
return
when defined(libp2p_agents_metrics):
if peer.shortAgent.len > 0: peer.shortAgent else: "unknown"
else:
"unknown"
shortAgent*: string
RPCHandler* = proc(peer: PubSubPeer, msg: RPCMsg): Future[void]
{.gcsafe, raises: [Defect].}
proc queuedSendBytes*(p: PubSubPeer): int =
if p.sendConn.isNil: 0
else: p.sendConn.queuedSendBytes()
func hash*(p: PubSubPeer): Hash =
p.peerId.hash
@@ -145,15 +85,13 @@ func `==`*(a, b: PubSubPeer): bool =
a.peerId == b.peerId
func shortLog*(p: PubSubPeer): string =
if p.isNil:
"PubSubPeer(nil)"
else:
shortLog(p.peerId)
chronicles.formatIt(PubSubPeer):
shortLog(it)
if p.isNil: "PubSubPeer(nil)"
else: shortLog(p.peerId)
chronicles.formatIt(PubSubPeer): shortLog(it)
proc connected*(p: PubSubPeer): bool =
not p.sendConn.isNil and not (p.sendConn.closed or p.sendConn.atEof)
not p.sendConn.isNil and not
(p.sendConn.closed or p.sendConn.atEof)
proc hasObservers*(p: PubSubPeer): bool =
p.observers != nil and anyIt(p.observers[], it != nil)
@@ -163,24 +101,28 @@ func outbound*(p: PubSubPeer): bool =
# in order to give priority to connections we make
# https://github.com/libp2p/specs/blob/master/pubsub/gossipsub/gossipsub-v1.1.md#outbound-mesh-quotas
# This behaviour is prescribed to counter sybil attacks and ensures that a coordinated inbound attack can never fully take over the mesh
if not p.sendConn.isNil and p.sendConn.transportDir == Direction.Out: true else: false
if not p.sendConn.isNil and p.sendConn.transportDir == Direction.Out:
true
else:
false
proc recvObservers*(p: PubSubPeer, msg: var RPCMsg) =
proc recvObservers(p: PubSubPeer, msg: var RPCMsg) =
# trigger hooks
if not (isNil(p.observers)) and p.observers[].len > 0:
if not(isNil(p.observers)) and p.observers[].len > 0:
for obs in p.observers[]:
if not (isNil(obs)): # TODO: should never be nil, but...
if not(isNil(obs)): # TODO: should never be nil, but...
obs.onRecv(p, msg)
proc sendObservers(p: PubSubPeer, msg: var RPCMsg) =
# trigger hooks
if not (isNil(p.observers)) and p.observers[].len > 0:
if not(isNil(p.observers)) and p.observers[].len > 0:
for obs in p.observers[]:
if not (isNil(obs)): # TODO: should never be nil, but...
if not(isNil(obs)): # TODO: should never be nil, but...
obs.onSend(p, msg)
proc handle*(p: PubSubPeer, conn: Connection) {.async.} =
debug "starting pubsub read loop", conn, peer = p, closed = conn.closed
debug "starting pubsub read loop",
conn, peer = p, closed = conn.closed
try:
try:
while not conn.atEof:
@@ -188,16 +130,31 @@ proc handle*(p: PubSubPeer, conn: Connection) {.async.} =
var data = await conn.readLp(p.maxMessageSize)
trace "read data from peer",
conn, peer = p, closed = conn.closed, data = data.shortLog
conn, peer = p, closed = conn.closed,
data = data.shortLog
await p.handler(p, data)
var rmsg = decodeRpcMsg(data)
data = newSeq[byte]() # Release memory
except PeerRateLimitError as exc:
debug "Peer rate limit exceeded, exiting read while",
conn, peer = p, error = exc.msg
except CatchableError as exc:
debug "Exception occurred in PubSubPeer.handle",
conn, peer = p, closed = conn.closed, exc = exc.msg
if rmsg.isErr():
notice "failed to decode msg from peer",
conn, peer = p, closed = conn.closed,
err = rmsg.error()
break
trace "decoded msg from peer",
conn, peer = p, closed = conn.closed,
msg = rmsg.get().shortLog
# trigger hooks
p.recvObservers(rmsg.get())
when defined(libp2p_expensive_metrics):
for m in rmsg.get().messages:
for t in m.topicIDs:
# metrics
libp2p_pubsub_received_messages.inc(labelValues = [$p.peerId, t])
await p.handler(p, rmsg.get())
finally:
await conn.close()
except CancelledError:
@@ -208,31 +165,14 @@ proc handle*(p: PubSubPeer, conn: Connection) {.async.} =
trace "Exception occurred in PubSubPeer.handle",
conn, peer = p, closed = conn.closed, exc = exc.msg
finally:
debug "exiting pubsub read loop", conn, peer = p, closed = conn.closed
proc closeSendConn(p: PubSubPeer, event: PubSubPeerEventKind) {.async.} =
if p.sendConn != nil:
trace "Removing send connection", p, conn = p.sendConn
await p.sendConn.close()
p.sendConn = nil
if not p.connectedFut.finished:
p.connectedFut.complete()
try:
if p.onEvent != nil:
p.onEvent(p, PubSubPeerEvent(kind: event))
except CancelledError as exc:
raise exc
except CatchableError as exc:
debug "Errors during diconnection events", error = exc.msg
# don't cleanup p.address else we leak some gossip stat table
debug "exiting pubsub read loop",
conn, peer = p, closed = conn.closed
proc connectOnce(p: PubSubPeer): Future[void] {.async.} =
try:
if p.connectedFut.finished:
p.connectedFut = newFuture[void]()
let newConn = await p.getConn().wait(5.seconds)
let newConn = await p.getConn()
if newConn.isNil:
raise (ref LPError)(msg: "Cannot establish send connection")
@@ -241,24 +181,29 @@ proc connectOnce(p: PubSubPeer): Future[void] {.async.} =
# stop working so we make an effort to only keep a single channel alive
trace "Get new send connection", p, newConn
# Careful of race conditions here.
# Topic subscription relies on either connectedFut
# to be completed, or onEvent to be called later
p.connectedFut.complete()
p.sendConn = newConn
p.address =
if p.sendConn.observedAddr.isSome:
some(p.sendConn.observedAddr.get)
else:
none(MultiAddress)
p.address = if p.sendConn.observedAddr.isSome: some(p.sendConn.observedAddr.get) else: none(MultiAddress)
if p.onEvent != nil:
p.onEvent(p, PubSubPeerEvent(kind: PubSubPeerEventKind.StreamOpened))
p.onEvent(p, PubSubPeerEvent(kind: PubSubPeerEventKind.Connected))
await handle(p, newConn)
finally:
await p.closeSendConn(PubSubPeerEventKind.StreamClosed)
if p.sendConn != nil:
trace "Removing send connection", p, conn = p.sendConn
await p.sendConn.close()
p.sendConn = nil
try:
if p.onEvent != nil:
p.onEvent(p, PubSubPeerEvent(kind: PubSubPeerEventKind.Disconnected))
except CancelledError as exc:
raise exc
except CatchableError as exc:
debug "Errors during diconnection events", error = exc.msg
# don't cleanup p.address else we leak some gossip stat table
proc connectImpl(p: PubSubPeer) {.async.} =
try:
@@ -266,10 +211,6 @@ proc connectImpl(p: PubSubPeer) {.async.} =
# send connection might get disconnected due to a timeout or an unrelated
# issue so we try to get a new one
while true:
if p.disconnected:
if not p.connectedFut.finished:
p.connectedFut.complete()
return
await connectOnce(p)
except CatchableError as exc: # never cancelled
debug "Could not establish send connection", msg = exc.msg
@@ -280,36 +221,36 @@ proc connect*(p: PubSubPeer) =
asyncSpawn connectImpl(p)
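In use, the connect loop is fire-and-forget; a hedged sketch of bringing up a send connection before sending, where rpc is a hypothetical RPCMsg and send follows the updated signature in this diff:
peer.connect()  # spawns connectImpl, which re-establishes the conn as needed
# ... later, once a send connection is up:
if peer.connected:
  peer.send(rpc, anonymize = false, isHighPriority = true)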
proc hasSendConn*(p: PubSubPeer): bool =
p.sendConn != nil
template sendMetrics(msg: RPCMsg): untyped =
when defined(libp2p_expensive_metrics):
for x in msg.messages:
# metrics
libp2p_pubsub_sent_messages.inc(labelValues = [$p.peerId, x.topic])
for t in x.topicIDs:
# metrics
libp2p_pubsub_sent_messages.inc(labelValues = [$p.peerId, t])
proc clearSendPriorityQueue(p: PubSubPeer) =
if p.rpcmessagequeue.sendPriorityQueue.len == 0:
return # fast path
proc sendEncoded*(p: PubSubPeer, msg: seq[byte]) {.raises: [Defect], async.} =
doAssert(not isNil(p), "pubsubpeer nil!")
while p.rpcmessagequeue.sendPriorityQueue.len > 0 and
p.rpcmessagequeue.sendPriorityQueue[0].finished:
discard p.rpcmessagequeue.sendPriorityQueue.popFirst()
if msg.len <= 0:
debug "empty message, skipping", p, msg = shortLog(msg)
return
while p.rpcmessagequeue.sendPriorityQueue.len > 0 and
p.rpcmessagequeue.sendPriorityQueue[^1].finished:
discard p.rpcmessagequeue.sendPriorityQueue.popLast()
if msg.len > p.maxMessageSize:
info "trying to send a too big for pubsub", maxSize=p.maxMessageSize, msgSize=msg.len
return
when defined(pubsubpeer_queue_metrics):
libp2p_gossipsub_priority_queue_size.set(
value = p.rpcmessagequeue.sendPriorityQueue.len.int64, labelValues = [$p.peerId]
)
if p.sendConn == nil:
discard await p.connectedFut.withTimeout(1.seconds)
var conn = p.sendConn
if conn == nil or conn.closed():
debug "No send connection, skipping message", p, msg = shortLog(msg)
return
trace "sending encoded msgs to peer", conn, encoded = shortLog(msg)
proc sendMsgContinue(conn: Connection, msgFut: Future[void]) {.async.} =
# Continuation for a pending `sendMsg` future from below
try:
await msgFut
await conn.writeLp(msg)
trace "sent pubsub message to remote", conn
except CatchableError as exc: # never cancelled
# Because we detach the send call from the currently executing task using
@@ -320,230 +261,41 @@ proc sendMsgContinue(conn: Connection, msgFut: Future[void]) {.async.} =
await conn.close() # This will clean up the send connection
proc sendMsgSlow(p: PubSubPeer, msg: seq[byte]) {.async.} =
# Slow path of `sendMsg` where msg is held in memory while send connection is
# being set up
if p.sendConn == nil:
# Wait for a send conn to be setup. `connectOnce` will
# complete this even if the sendConn setup failed
discard await race(p.connectedFut)
proc send*(p: PubSubPeer, msg: RPCMsg, anonymize: bool) {.raises: [Defect].} =
trace "sending msg to peer", peer = p, rpcMsg = shortLog(msg)
var conn = p.sendConn
if conn == nil or conn.closed():
debug "No send connection", p, msg = shortLog(msg)
return
trace "sending encoded msg to peer", conn, encoded = shortLog(msg)
await sendMsgContinue(conn, conn.writeLp(msg))
proc sendMsg(p: PubSubPeer, msg: seq[byte]): Future[void] =
if p.sendConn != nil and not p.sendConn.closed():
# Fast path that avoids copying msg (which happens for {.async.})
let conn = p.sendConn
trace "sending encoded msg to peer", conn, encoded = shortLog(msg)
let f = conn.writeLp(msg)
if not f.completed():
sendMsgContinue(conn, f)
else:
f
else:
sendMsgSlow(p, msg)
proc sendEncoded*(p: PubSubPeer, msg: seq[byte], isHighPriority: bool): Future[void] =
## Asynchronously sends an encoded message to a specified `PubSubPeer`.
##
## Parameters:
## - `p`: The `PubSubPeer` instance to which the message is to be sent.
## - `msg`: The message to be sent, encoded as a sequence of bytes (`seq[byte]`).
## - `isHighPriority`: A boolean indicating whether the message should be treated as high priority.
## High priority messages are sent immediately, while low priority messages are queued and sent only after all high
## priority messages have been sent.
doAssert(not isNil(p), "pubsubpeer nil!")
p.clearSendPriorityQueue()
# When queues are empty, skipping the non-priority queue for low priority
# messages reduces latency
let emptyQueues =
(
p.rpcmessagequeue.sendPriorityQueue.len() +
p.rpcmessagequeue.nonPriorityQueue.len()
) == 0
if msg.len <= 0:
debug "empty message, skipping", p, msg = shortLog(msg)
Future[void].completed()
elif msg.len > p.maxMessageSize:
info "trying to send a msg too big for pubsub",
maxSize = p.maxMessageSize, msgSize = msg.len
Future[void].completed()
elif isHighPriority or emptyQueues:
let f = p.sendMsg(msg)
if not f.finished:
p.rpcmessagequeue.sendPriorityQueue.addLast(f)
when defined(pubsubpeer_queue_metrics):
libp2p_gossipsub_priority_queue_size.inc(labelValues = [$p.peerId])
f
else:
if len(p.rpcmessagequeue.nonPriorityQueue) >= p.maxNumElementsInNonPriorityQueue:
if not p.disconnected:
p.disconnected = true
libp2p_pubsub_disconnects_over_non_priority_queue_limit.inc()
p.closeSendConn(PubSubPeerEventKind.DisconnectionRequested)
else:
Future[void].completed()
else:
let f = p.rpcmessagequeue.nonPriorityQueue.addLast(msg)
when defined(pubsubpeer_queue_metrics):
libp2p_gossipsub_non_priority_queue_size.inc(labelValues = [$p.peerId])
f
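A short usage sketch of the two priority paths (per the updated signature; rpc and p are hypothetical here):
let encoded = encodeRpcMsg(rpc, anonymize = true)
# control traffic is written immediately:
asyncSpawn p.sendEncoded(encoded, isHighPriority = true)
# relayed or gossiped data waits behind any pending priority sends:
asyncSpawn p.sendEncoded(encoded, isHighPriority = false)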
iterator splitRPCMsg(
peer: PubSubPeer, rpcMsg: RPCMsg, maxSize: int, anonymize: bool
): seq[byte] =
## This iterator takes an `RPCMsg` and sequentially repackages its Messages into new `RPCMsg` instances.
## Each new `RPCMsg` accumulates Messages until reaching the specified `maxSize`. If a single Message
## exceeds the `maxSize` when trying to fit into an empty `RPCMsg`, the latter is skipped as too large to send.
## Every constructed `RPCMsg` is then encoded, optionally anonymized, and yielded as a sequence of bytes.
var currentRPCMsg = rpcMsg
currentRPCMsg.messages = newSeq[Message]()
var currentSize = byteSize(currentRPCMsg)
for msg in rpcMsg.messages:
let msgSize = byteSize(msg)
# Check if adding the next message will exceed maxSize
if float(currentSize + msgSize) * 1.1 > float(maxSize):
# Guessing 10% protobuf overhead
if currentRPCMsg.messages.len == 0:
trace "message too big to sent", peer, rpcMsg = shortLog(currentRPCMsg)
continue # Skip this message
trace "sending msg to peer", peer, rpcMsg = shortLog(currentRPCMsg)
yield encodeRpcMsg(currentRPCMsg, anonymize)
currentRPCMsg = RPCMsg()
currentSize = 0
currentRPCMsg.messages.add(msg)
currentSize += msgSize
# Check if there is a non-empty currentRPCMsg left to be added
if currentSize > 0 and currentRPCMsg.messages.len > 0:
trace "sending msg to peer", peer, rpcMsg = shortLog(currentRPCMsg)
yield encodeRpcMsg(currentRPCMsg, anonymize)
else:
trace "message too big to sent", peer, rpcMsg = shortLog(currentRPCMsg)
proc send*(
    p: PubSubPeer, msg: RPCMsg, anonymize: bool, isHighPriority: bool
) {.raises: [].} =
  ## Asynchronously sends an `RPCMsg` to a specified `PubSubPeer` with an option for anonymization.
  ##
  ## Parameters:
  ## - `p`: The `PubSubPeer` instance to which the message is to be sent.
  ## - `msg`: The `RPCMsg` instance representing the message to be sent.
  ## - `anonymize`: A boolean flag indicating whether the message should be sent with anonymization.
  ## - `isHighPriority`: A boolean flag indicating whether the message should be treated as high priority.
  ##   High priority messages are sent immediately, while low priority messages are queued and sent
  ##   only after all high priority messages have been sent.
  # When sending messages, we take care to re-encode them with the right
  # anonymization flag to ensure that we're not penalized for sending invalid
  # or malicious data on the wire - in particular, re-encoding protects against
  # some forms of valid but redundantly encoded protobufs with unknown or
  # duplicated fields
  let encoded =
    if p.hasObservers():
      var mm = msg
      # trigger send hooks
      p.sendObservers(mm)
      sendMetrics(mm)
      encodeRpcMsg(mm, anonymize)
    else:
      # If there are no send hooks, we redundantly re-encode the message to
      # protobuf for every peer - this could easily be improved!
      sendMetrics(msg)
      encodeRpcMsg(msg, anonymize)

  if encoded.len > p.maxMessageSize and msg.messages.len > 1:
    for encodedSplitMsg in splitRPCMsg(p, msg, p.maxMessageSize, anonymize):
      asyncSpawn p.sendEncoded(encodedSplitMsg, isHighPriority)
  else:
    # If the message size is within limits, send it as is
    trace "sending msg to peer", peer = p, rpcMsg = shortLog(msg)
    asyncSpawn p.sendEncoded(encoded, isHighPriority)
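A hypothetical call site for `send`, assuming an already-connected `peer: PubSubPeer` and a populated `rpc: RPCMsg` (both names are placeholders): control traffic such as GRAFT/PRUNE is typically worth sending as high priority, while regular published data can go through the non-priority queue.

# Illustrative only - `peer` and `rpc` are assumed to exist at the call site.
peer.send(rpc, anonymize = false, isHighPriority = true) # e.g. GRAFT/PRUNE
peer.send(rpc, anonymize = false, isHighPriority = false) # e.g. published data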
proc canAskIWant*(p: PubSubPeer, msgId: MessageId): bool =
  for sentIHave in p.sentIHaves.mitems():
    if msgId in sentIHave:
      sentIHave.excl(msgId)
      return true
  return false
proc sendNonPriorityTask(p: PubSubPeer) {.async.} =
  while true:
    # we send non-priority messages only if there are no pending priority messages
    let msg = await p.rpcmessagequeue.nonPriorityQueue.popFirst()
    while p.rpcmessagequeue.sendPriorityQueue.len > 0:
      p.clearSendPriorityQueue()
      # waiting for the last future minimizes the number of times we have to
      # wait for something (each wait = performance cost) -
      # clearSendPriorityQueue ensures we're not waiting for an already-finished
      # future
      if p.rpcmessagequeue.sendPriorityQueue.len > 0:
        # `race` prevents `p.rpcmessagequeue.sendPriorityQueue[^1]` from being
        # cancelled when this task is cancelled
        discard await race(p.rpcmessagequeue.sendPriorityQueue[^1])
    when defined(pubsubpeer_queue_metrics):
      libp2p_gossipsub_non_priority_queue_size.dec(labelValues = [$p.peerId])
    await p.sendMsg(msg)
proc startSendNonPriorityTask(p: PubSubPeer) =
  debug "starting sendNonPriorityTask", p
  if p.rpcmessagequeue.sendNonPriorityTask.isNil:
    p.rpcmessagequeue.sendNonPriorityTask = p.sendNonPriorityTask()

proc stopSendNonPriorityTask*(p: PubSubPeer) =
  if not p.rpcmessagequeue.sendNonPriorityTask.isNil:
    debug "stopping sendNonPriorityTask", p
    p.rpcmessagequeue.sendNonPriorityTask.cancelSoon()
    p.rpcmessagequeue.sendNonPriorityTask = nil
    p.rpcmessagequeue.sendPriorityQueue.clear()
    p.rpcmessagequeue.nonPriorityQueue.clear()
    when defined(pubsubpeer_queue_metrics):
      libp2p_gossipsub_priority_queue_size.set(labelValues = [$p.peerId], value = 0)
      libp2p_gossipsub_non_priority_queue_size.set(labelValues = [$p.peerId], value = 0)
proc new(T: typedesc[RpcMessageQueue]): T =
  return T(
    sendPriorityQueue: initDeque[Future[void]](),
    nonPriorityQueue: newAsyncQueue[seq[byte]](),
  )
proc new*(
    T: typedesc[PubSubPeer],
    peerId: PeerId,
    getConn: GetConn,
    onEvent: OnEvent,
    codec: string,
    maxMessageSize: int,
    maxNumElementsInNonPriorityQueue: int = DefaultMaxNumElementsInNonPriorityQueue,
    overheadRateLimitOpt: Opt[TokenBucket] = Opt.none(TokenBucket),
): T =
  result = T(
    getConn: getConn,
    onEvent: onEvent,
    codec: codec,
    peerId: peerId,
    connectedFut: newFuture[void](),
    maxMessageSize: maxMessageSize,
    overheadRateLimitOpt: overheadRateLimitOpt,
    rpcmessagequeue: RpcMessageQueue.new(),
    maxNumElementsInNonPriorityQueue: maxNumElementsInNonPriorityQueue,
  )
  result.sentIHaves.addFirst(default(HashSet[MessageId]))
  result.iDontWants.addFirst(default(HashSet[SaltedId]))
  result.startSendNonPriorityTask()
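A hedged construction sketch for the proc above; `somePeerId`, `connectToPeer` and `onPeerEvent` are hypothetical values supplied by the host application, not library symbols, and the codec string is only an example:

let peer = PubSubPeer.new(
  peerId = somePeerId,
  getConn = connectToPeer, # GetConn: dials the peer and returns a connection
  onEvent = onPeerEvent, # OnEvent: notified on (dis)connection events
  codec = "/meshsub/1.1.0",
  maxMessageSize = 1024 * 1024,
)
# the defaulted parameters cap the non-priority queue length and can
# optionally rate-limit protocol overhead via a TokenBucket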

View File

@@ -7,17 +7,20 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}

import hashes
import chronicles, metrics, stew/[byteutils, endians2]
import
  ./messages,
  ./protobuf,
  ../../../peerid,
  ../../../peerinfo,
  ../../../crypto/crypto,
  ../../../protobuf/minprotobuf,
  ../../../protocols/pubsub/errors
export errors, messages
@@ -26,9 +29,7 @@ logScope:
const PubSubPrefix = toBytes("libp2p-pubsub:")

declareCounter(
  libp2p_pubsub_sig_verify_success, "pubsub successfully validated messages"
)
declareCounter(libp2p_pubsub_sig_verify_failure, "pubsub failed to validate messages")
func defaultMsgIdProvider*(m: Message): Result[MessageId, ValidationResult] =
@@ -39,7 +40,7 @@ func defaultMsgIdProvider*(m: Message): Result[MessageId, ValidationResult] =
    err ValidationResult.Reject

proc sign*(msg: Message, privateKey: PrivateKey): CryptoResult[seq[byte]] =
  ok((?privateKey.sign(PubSubPrefix & encodeMessage(msg, false))).getBytes())

proc verify*(m: Message): bool =
  if m.signature.len > 0 and m.key.len > 0:
@@ -64,26 +65,23 @@ proc init*(
    data: seq[byte],
    topic: string,
    seqno: Option[uint64],
    sign: bool = true,
): Message {.gcsafe, raises: [LPError].} =
  var msg = Message(data: data, topic: topic)

  # order matters, we want to include seqno in the signature
  seqno.withValue(seqn):
    msg.seqno = @(seqn.toBytesBE())

  peer.withValue(peer):
    msg.fromPeer = peer.peerId
    if sign:
      msg.signature = sign(msg, peer.privateKey).expect("Couldn't sign message!")

      msg.key = peer.privateKey
        .getPublicKey()
        .expect("Invalid private key!")
        .getBytes()
        .expect("Couldn't get public key bytes!")
  else:
    if sign:
      raise (ref LPError)(msg: "Cannot sign message without peer info")

  msg
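As a usage sketch, assuming a local `peerInfo: PeerInfo` that carries a private key (hypothetical at this call site), a signed message round-trips through `verify`:

let msg = Message.init(
  some(peerInfo), "hello".toBytes(), "my-topic", some(1'u64), sign = true
)
assert msg.verify() # signature covers PubSubPrefix & the encoded message, incl. seqno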
@@ -92,11 +90,11 @@ proc init*(
    peerId: PeerId,
    data: seq[byte],
    topic: string,
    seqno: Option[uint64],
): Message {.gcsafe, raises: [LPError].} =
  var msg = Message(data: data, topic: topic)
  msg.fromPeer = peerId

  seqno.withValue(seqn):
    msg.seqno = @(seqn.toBytesBE())
  msg

View File

@@ -7,106 +7,97 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}

import options, sequtils
import ../../../[peerid, routing_record, utility]
export options
proc expectedFields[T](
    t: typedesc[T], existingFieldNames: seq[string]
) {.raises: [CatchableError].} =
  var fieldNames: seq[string]
  for name, _ in fieldPairs(T()):
    fieldNames &= name
  if fieldNames != existingFieldNames:
    fieldNames.keepIf(
      proc(it: string): bool =
        it notin existingFieldNames
    )
    raise newException(
      CatchableError,
      $T &
        " fields changed, please search for and revise all relevant procs. New fields: " &
        $fieldNames,
    )
type
  PeerInfoMsg* = object
    peerId*: PeerId
    signedPeerRecord*: seq[byte]

  SubOpts* = object
    subscribe*: bool
    topic*: string

  MessageId* = seq[byte]

  SaltedId* = object
    # Salted hash of message ID - used instead of the ordinary message ID to
    # avoid hash poisoning attacks and to make memory usage more predictable
    # with respect to the variable-length message id
    data*: MDigest[256]

  Message* = object
    fromPeer*: PeerId
    data*: seq[byte]
    seqno*: seq[byte]
    topic*: string
    signature*: seq[byte]
    key*: seq[byte]

  ControlMessage* = object
    ihave*: seq[ControlIHave]
    iwant*: seq[ControlIWant]
    graft*: seq[ControlGraft]
    prune*: seq[ControlPrune]
    idontwant*: seq[ControlIWant]

  ControlIHave* = object
    topicID*: string
    messageIDs*: seq[MessageId]

  ControlIWant* = object
    messageIDs*: seq[MessageId]

  ControlGraft* = object
    topicID*: string

  ControlPrune* = object
    topicID*: string
    peers*: seq[PeerInfoMsg]
    backoff*: uint64

  RPCMsg* = object
    subscriptions*: seq[SubOpts]
    messages*: seq[Message]
    control*: Option[ControlMessage]
    ping*: seq[byte]
    pong*: seq[byte]
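To illustrate the shapes above, a control-only RPC carrying a single GRAFT could be built like this (the topic name is illustrative):

let rpc = RPCMsg(
  control: some(ControlMessage(graft: @[ControlGraft(topicID: "my-topic")]))
)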
func withSubs*(T: type RPCMsg, topics: openArray[string], subscribe: bool): T =
  T(subscriptions: topics.mapIt(SubOpts(subscribe: subscribe, topic: it)))
func shortLog*(s: ControlIHave): auto =
  (topic: s.topicID.shortLog, messageIDs: mapIt(s.messageIDs, it.shortLog))

func shortLog*(s: ControlIWant): auto =
  (messageIDs: mapIt(s.messageIDs, it.shortLog))

func shortLog*(s: ControlGraft): auto =
  (topic: s.topicID.shortLog)

func shortLog*(s: ControlPrune): auto =
  (topic: s.topicID.shortLog)

func shortLog*(c: ControlMessage): auto =
  (
    ihave: mapIt(c.ihave, it.shortLog),
    iwant: mapIt(c.iwant, it.shortLog),
    graft: mapIt(c.graft, it.shortLog),
    prune: mapIt(c.prune, it.shortLog),
  )
func shortLog*(msg: Message): auto =
@@ -114,76 +105,21 @@ func shortLog*(msg: Message): auto =
  (
    fromPeer: msg.fromPeer.shortLog,
    data: msg.data.shortLog,
    seqno: msg.seqno.shortLog,
    topic: msg.topic,
    signature: msg.signature.shortLog,
    key: msg.key.shortLog,
  )
func shortLog*(m: RPCMsg): auto =
  (
    subscriptions: m.subscriptions,
    messages: mapIt(m.messages, it.shortLog),
    control: m.control.get(ControlMessage()).shortLog,
  )
static:
  expectedFields(PeerInfoMsg, @["peerId", "signedPeerRecord"])
proc byteSize(peerInfo: PeerInfoMsg): int =
  peerInfo.peerId.len + peerInfo.signedPeerRecord.len

static:
  expectedFields(SubOpts, @["subscribe", "topic"])
proc byteSize(subOpts: SubOpts): int =
  1 + subOpts.topic.len # 1 byte for the bool

static:
  expectedFields(Message, @["fromPeer", "data", "seqno", "topic", "signature", "key"])
proc byteSize*(msg: Message): int =
  msg.fromPeer.len + msg.data.len + msg.seqno.len + msg.signature.len + msg.key.len +
    msg.topic.len

proc byteSize*(msgs: seq[Message]): int =
  msgs.foldl(a + b.byteSize, 0)

static:
  expectedFields(ControlIHave, @["topicID", "messageIDs"])
proc byteSize(controlIHave: ControlIHave): int =
  controlIHave.topicID.len + controlIHave.messageIDs.foldl(a + b.len, 0)

proc byteSize*(ihaves: seq[ControlIHave]): int =
  ihaves.foldl(a + b.byteSize, 0)

static:
  expectedFields(ControlIWant, @["messageIDs"])
proc byteSize(controlIWant: ControlIWant): int =
  controlIWant.messageIDs.foldl(a + b.len, 0)

proc byteSize*(iwants: seq[ControlIWant]): int =
  iwants.foldl(a + b.byteSize, 0)

static:
  expectedFields(ControlGraft, @["topicID"])
proc byteSize(controlGraft: ControlGraft): int =
  controlGraft.topicID.len

static:
  expectedFields(ControlPrune, @["topicID", "peers", "backoff"])
proc byteSize(controlPrune: ControlPrune): int =
  controlPrune.topicID.len + controlPrune.peers.foldl(a + b.byteSize, 0) + 8
    # 8 bytes for uint64

static:
  expectedFields(ControlMessage, @["ihave", "iwant", "graft", "prune", "idontwant"])
proc byteSize(control: ControlMessage): int =
  control.ihave.foldl(a + b.byteSize, 0) + control.iwant.foldl(a + b.byteSize, 0) +
    control.graft.foldl(a + b.byteSize, 0) + control.prune.foldl(a + b.byteSize, 0) +
    control.idontwant.foldl(a + b.byteSize, 0)

static:
  expectedFields(RPCMsg, @["subscriptions", "messages", "control", "ping", "pong"])
proc byteSize*(rpc: RPCMsg): int =
  result =
    rpc.subscriptions.foldl(a + b.byteSize, 0) + byteSize(rpc.messages) + rpc.ping.len +
    rpc.pong.len
  rpc.control.withValue(ctrl):
    result += ctrl.byteSize
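A quick, illustrative check of the accounting above, useful when reasoning about the budgets splitRPCMsg works with; the values are made up for the example:

let m = Message(data: newSeq[byte](1000), topic: "t")
assert byteSize(m) == 1001 # 1000 data bytes + 1 topic byte; other fields empty
assert byteSize(@[m, m]) == 2002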

Some files were not shown because too many files have changed in this diff.