mirror of
https://github.com/vacp2p/nim-libp2p.git
synced 2026-01-09 02:38:19 -05:00
feat(kad-dht): routing table (#1454)
This commit is contained in:
4
libp2p/protocols/kademlia/consts.nim
Normal file
4
libp2p/protocols/kademlia/consts.nim
Normal file
@@ -0,0 +1,4 @@
|
||||
const
|
||||
IdLength* = 32 # 256-bit IDs
|
||||
k* = 20 # replication parameter
|
||||
maxBuckets* = 256
|
||||
48
libp2p/protocols/kademlia/keys.nim
Normal file
48
libp2p/protocols/kademlia/keys.nim
Normal file
@@ -0,0 +1,48 @@
|
||||
import ../../peerid
|
||||
import ./consts
|
||||
import chronicles
|
||||
import stew/byteutils
|
||||
|
||||
type
|
||||
KeyType* {.pure.} = enum
|
||||
Unhashed
|
||||
Raw
|
||||
PeerId
|
||||
|
||||
Key* = object
|
||||
case kind*: KeyType
|
||||
of KeyType.PeerId:
|
||||
peerId*: PeerId
|
||||
of KeyType.Raw, KeyType.Unhashed:
|
||||
data*: array[IdLength, byte]
|
||||
|
||||
proc toKey*(s: seq[byte]): Key =
|
||||
doAssert s.len == IdLength
|
||||
var data: array[IdLength, byte]
|
||||
for i in 0 ..< IdLength:
|
||||
data[i] = s[i]
|
||||
return Key(kind: KeyType.Raw, data: data)
|
||||
|
||||
proc toKey*(p: PeerId): Key =
|
||||
return Key(kind: KeyType.PeerId, peerId: p)
|
||||
|
||||
proc getBytes*(k: Key): seq[byte] =
|
||||
return
|
||||
case k.kind
|
||||
of KeyType.PeerId:
|
||||
k.peerId.getBytes()
|
||||
of KeyType.Raw, KeyType.Unhashed:
|
||||
@(k.data)
|
||||
|
||||
template `==`*(a, b: Key): bool =
|
||||
a.getBytes() == b.getBytes() and a.kind == b.kind
|
||||
|
||||
proc shortLog*(k: Key): string =
|
||||
case k.kind
|
||||
of KeyType.PeerId:
|
||||
"PeerId:" & $k.peerId
|
||||
of KeyType.Raw, KeyType.Unhashed:
|
||||
$k.kind & ":" & toHex(k.data)
|
||||
|
||||
chronicles.formatIt(Key):
|
||||
shortLog(it)
|
||||
129
libp2p/protocols/kademlia/routingtable.nim
Normal file
129
libp2p/protocols/kademlia/routingtable.nim
Normal file
@@ -0,0 +1,129 @@
|
||||
import algorithm
|
||||
import bearssl/rand
|
||||
import chronos
|
||||
import chronicles
|
||||
import ./consts
|
||||
import ./keys
|
||||
import ./xordistance
|
||||
import ../../peerid
|
||||
import sequtils
|
||||
|
||||
logScope:
|
||||
topics = "kad-dht rtable"
|
||||
|
||||
type
|
||||
NodeEntry* = object
|
||||
nodeId*: Key
|
||||
lastSeen*: Moment
|
||||
|
||||
Bucket* = object
|
||||
peers*: seq[NodeEntry]
|
||||
|
||||
RoutingTable* = ref object
|
||||
selfId*: Key
|
||||
buckets*: seq[Bucket]
|
||||
|
||||
proc init*(T: typedesc[RoutingTable], selfId: Key): T =
|
||||
return RoutingTable(selfId: selfId, buckets: @[])
|
||||
|
||||
proc bucketIndex*(selfId, key: Key): int =
|
||||
return xorDistance(selfId, key).leadingZeros
|
||||
|
||||
proc peerIndexInBucket(bucket: var Bucket, nodeId: Key): Opt[int] =
|
||||
for i, p in bucket.peers:
|
||||
if p.nodeId == nodeId:
|
||||
return Opt.some(i)
|
||||
return Opt.none(int)
|
||||
|
||||
proc insert*(rtable: var RoutingTable, nodeId: Key): bool =
|
||||
if nodeId == rtable.selfId:
|
||||
return false # No self insertion
|
||||
|
||||
let idx = bucketIndex(rtable.selfId, nodeId)
|
||||
if idx >= maxBuckets:
|
||||
trace "cannot insert node. max buckets have been reached",
|
||||
nodeId, bucketIdx = idx, maxBuckets
|
||||
return false
|
||||
|
||||
if idx >= rtable.buckets.len:
|
||||
# expand buckets lazily if needed
|
||||
rtable.buckets.setLen(idx + 1)
|
||||
|
||||
var bucket = rtable.buckets[idx]
|
||||
let keyx = peerIndexInBucket(bucket, nodeId)
|
||||
if keyx.isSome:
|
||||
bucket.peers[keyx.unsafeValue].lastSeen = Moment.now()
|
||||
elif bucket.peers.len < k:
|
||||
bucket.peers.add(NodeEntry(nodeId: nodeId, lastSeen: Moment.now()))
|
||||
else:
|
||||
# TODO: eviction policy goes here, rn we drop the node
|
||||
trace "cannot insert node in bucket, dropping node",
|
||||
nodeId, bucket = k, bucketIdx = idx
|
||||
return false
|
||||
|
||||
rtable.buckets[idx] = bucket
|
||||
return true
|
||||
|
||||
proc insert*(rtable: var RoutingTable, peerId: PeerId): bool =
|
||||
insert(rtable, peerId.toKey())
|
||||
|
||||
proc findClosest*(rtable: RoutingTable, targetId: Key, count: int): seq[Key] =
|
||||
var allNodes: seq[Key] = @[]
|
||||
|
||||
for bucket in rtable.buckets:
|
||||
for p in bucket.peers:
|
||||
allNodes.add(p.nodeId)
|
||||
|
||||
allNodes.sort(
|
||||
proc(a, b: Key): int =
|
||||
cmp(xorDistance(a, targetId), xorDistance(b, targetId))
|
||||
)
|
||||
|
||||
return allNodes[0 ..< min(count, allNodes.len)]
|
||||
|
||||
proc findClosestPeers*(rtable: RoutingTable, targetId: Key, count: int): seq[PeerId] =
|
||||
findClosest(rtable, targetId, count).mapIt(it.peerId)
|
||||
|
||||
proc isStale*(bucket: Bucket): bool =
|
||||
if bucket.peers.len == 0:
|
||||
return true
|
||||
for p in bucket.peers:
|
||||
if Moment.now() - p.lastSeen > 30.minutes:
|
||||
return true
|
||||
return false
|
||||
|
||||
proc randomKeyInBucketRange*(
|
||||
selfId: Key, bucketIndex: int, rng: ref HmacDrbgContext
|
||||
): Key =
|
||||
var raw = selfId.getBytes()
|
||||
|
||||
# zero out higher bits
|
||||
for i in 0 ..< bucketIndex:
|
||||
let byteIdx = i div 8
|
||||
let bitInByte = 7 - (i mod 8)
|
||||
raw[byteIdx] = raw[byteIdx] and not (1'u8 shl bitInByte)
|
||||
|
||||
# flip the target bit
|
||||
let tgtByte = bucketIndex div 8
|
||||
let tgtBitInByte = 7 - (bucketIndex mod 8)
|
||||
raw[tgtByte] = raw[tgtByte] xor (1'u8 shl tgtBitInByte)
|
||||
|
||||
# randomize all less significant bits
|
||||
let totalBits = raw.len * 8
|
||||
let lsbStart = bucketIndex + 1
|
||||
let lsbBytes = (totalBits - lsbStart + 7) div 8
|
||||
var randomBuf = newSeq[byte](lsbBytes)
|
||||
hmacDrbgGenerate(rng[], randomBuf)
|
||||
|
||||
for i in lsbStart ..< totalBits:
|
||||
let byteIdx = i div 8
|
||||
let bitInByte = 7 - (i mod 8)
|
||||
let lsbByte = (i - lsbStart) div 8
|
||||
let lsbBit = 7 - ((i - lsbStart) mod 8)
|
||||
let randBit = (randomBuf[lsbByte] shr lsbBit) and 1
|
||||
if randBit == 1:
|
||||
raw[byteIdx] = raw[byteIdx] or (1'u8 shl bitInByte)
|
||||
else:
|
||||
raw[byteIdx] = raw[byteIdx] and not (1'u8 shl bitInByte)
|
||||
|
||||
return raw.toKey()
|
||||
55
libp2p/protocols/kademlia/xordistance.nim
Normal file
55
libp2p/protocols/kademlia/xordistance.nim
Normal file
@@ -0,0 +1,55 @@
|
||||
import ./consts
|
||||
import ./keys
|
||||
import nimcrypto/sha2
|
||||
import ../../peerid
|
||||
|
||||
type XorDistance* = array[IdLength, byte]
|
||||
|
||||
proc countLeadingZeroBits*(b: byte): int =
|
||||
for i in 0 .. 7:
|
||||
if (b and (0x80'u8 shr i)) != 0:
|
||||
return i
|
||||
return 8
|
||||
|
||||
proc leadingZeros*(dist: XorDistance): int =
|
||||
for i in 0 ..< dist.len:
|
||||
if dist[i] != 0:
|
||||
return i * 8 + countLeadingZeroBits(dist[i])
|
||||
return dist.len * 8
|
||||
|
||||
proc cmp*(a, b: XorDistance): int =
|
||||
for i in 0 ..< IdLength:
|
||||
if a[i] < b[i]:
|
||||
return -1
|
||||
elif a[i] > b[i]:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
proc `<`*(a, b: XorDistance): bool =
|
||||
cmp(a, b) < 0
|
||||
|
||||
proc `<=`*(a, b: XorDistance): bool =
|
||||
cmp(a, b) <= 0
|
||||
|
||||
proc hashFor(k: Key): seq[byte] =
|
||||
return
|
||||
@(
|
||||
case k.kind
|
||||
of KeyType.PeerId:
|
||||
sha256.digest(k.peerId.getBytes()).data
|
||||
of KeyType.Raw:
|
||||
sha256.digest(k.data).data
|
||||
of KeyType.Unhashed:
|
||||
k.data
|
||||
)
|
||||
|
||||
proc xorDistance*(a, b: Key): XorDistance =
|
||||
let hashA = a.hashFor()
|
||||
let hashB = b.hashFor()
|
||||
var response: XorDistance
|
||||
for i in 0 ..< hashA.len:
|
||||
response[i] = hashA[i] xor hashB[i]
|
||||
return response
|
||||
|
||||
proc xorDistance*(a: PeerId, b: Key): XorDistance =
|
||||
xorDistance(a.toKey(), b)
|
||||
83
tests/kademlia/testroutingtable.nim
Normal file
83
tests/kademlia/testroutingtable.nim
Normal file
@@ -0,0 +1,83 @@
|
||||
{.used.}
|
||||
|
||||
# Nim-Libp2p
|
||||
# Copyright (c) 2023 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import unittest
|
||||
import chronos
|
||||
import ../../libp2p/crypto/crypto
|
||||
import ../../libp2p/protocols/kademlia/[routingtable, consts, keys]
|
||||
|
||||
proc testKey*(x: byte): Key =
|
||||
var buf: array[IdLength, byte]
|
||||
buf[31] = x
|
||||
return Key(kind: KeyType.Unhashed, data: buf)
|
||||
|
||||
let rng = crypto.newRng()
|
||||
|
||||
suite "routing table":
|
||||
test "inserts single key in correct bucket":
|
||||
let selfId = testKey(0)
|
||||
var rt = RoutingTable.init(selfId)
|
||||
let other = testKey(0b10000000)
|
||||
discard rt.insert(other)
|
||||
|
||||
let idx = bucketIndex(selfId, other)
|
||||
check:
|
||||
rt.buckets.len > idx
|
||||
rt.buckets[idx].peers.len == 1
|
||||
rt.buckets[idx].peers[0].nodeId == other
|
||||
|
||||
test "does not insert beyond capacity":
|
||||
let selfId = testKey(0)
|
||||
var rt = RoutingTable.init(selfId)
|
||||
let targetBucket = 6
|
||||
for _ in 0 ..< k + 5:
|
||||
var kid = randomKeyInBucketRange(selfId, targetBucket, rng)
|
||||
kid.kind = KeyType.Unhashed
|
||||
# Overriding so we don't use sha for comparing xor distances
|
||||
discard rt.insert(kid)
|
||||
|
||||
check targetBucket < rt.buckets.len
|
||||
let bucket = rt.buckets[targetBucket]
|
||||
check bucket.peers.len <= k
|
||||
|
||||
test "findClosest returns sorted keys":
|
||||
let selfId = testKey(0)
|
||||
var rt = RoutingTable.init(selfId)
|
||||
let ids = @[testKey(1), testKey(2), testKey(3), testKey(4), testKey(5)]
|
||||
for id in ids:
|
||||
discard rt.insert(id)
|
||||
|
||||
let res = rt.findClosest(testKey(1), 3)
|
||||
|
||||
check:
|
||||
res.len == 3
|
||||
res == @[testKey(1), testKey(3), testKey(2)]
|
||||
|
||||
test "isStale returns true for empty or old keys":
|
||||
var bucket: Bucket
|
||||
check isStale(bucket) == true
|
||||
|
||||
bucket.peers = @[NodeEntry(nodeId: testKey(1), lastSeen: Moment.now() - 40.minutes)]
|
||||
check isStale(bucket) == true
|
||||
|
||||
bucket.peers = @[NodeEntry(nodeId: testKey(1), lastSeen: Moment.now())]
|
||||
check isStale(bucket) == false
|
||||
|
||||
test "randomKeyInBucketRange returns id at correct distance":
|
||||
let selfId = testKey(0)
|
||||
let targetBucket = 3
|
||||
var rid = randomKeyInBucketRange(selfId, targetBucket, rng)
|
||||
rid.kind = KeyType.Unhashed
|
||||
# Overriding so we don't use sha for comparing xor distances
|
||||
let idx = bucketIndex(selfId, rid)
|
||||
check:
|
||||
idx == targetBucket
|
||||
rid != selfId
|
||||
54
tests/kademlia/testxordistance.nim
Normal file
54
tests/kademlia/testxordistance.nim
Normal file
@@ -0,0 +1,54 @@
|
||||
{.used.}
|
||||
|
||||
# Nim-Libp2p
|
||||
# Copyright (c) 2023 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import unittest
|
||||
import chronos
|
||||
import ../../libp2p/protocols/kademlia/[consts, keys, xordistance]
|
||||
|
||||
suite "xor distance":
|
||||
test "countLeadingZeroBits works":
|
||||
check countLeadingZeroBits(0b00000000'u8) == 8
|
||||
check countLeadingZeroBits(0b10000000'u8) == 0
|
||||
check countLeadingZeroBits(0b01000000'u8) == 1
|
||||
check countLeadingZeroBits(0b00000001'u8) == 7
|
||||
|
||||
test "leadingZeros of xor distance":
|
||||
var d: XorDistance
|
||||
for i in 0 ..< IdLength:
|
||||
d[i] = 0
|
||||
check leadingZeros(d) == IdLength * 8
|
||||
|
||||
d[0] = 0b00010000
|
||||
check leadingZeros(d) == 3
|
||||
|
||||
d[0] = 0
|
||||
d[1] = 0b00100000
|
||||
check leadingZeros(d) == 10
|
||||
|
||||
test "xorDistance of identical keys is zero":
|
||||
let k = @[
|
||||
1'u8, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6,
|
||||
7, 8, 9, 0, 1, 2,
|
||||
].toKey()
|
||||
let dist = xorDistance(k, k)
|
||||
check:
|
||||
leadingZeros(dist) == IdLength * 8
|
||||
dist == default(XorDistance)
|
||||
|
||||
test "cmp gives correct order":
|
||||
var a: XorDistance
|
||||
var b: XorDistance
|
||||
a[0] = 0x01
|
||||
b[0] = 0x02
|
||||
check a < b
|
||||
check cmp(a, b) == -1
|
||||
check cmp(b, a) == 1
|
||||
check cmp(a, a) == 0
|
||||
@@ -34,4 +34,4 @@ import
|
||||
testdiscovery, testyamux, testautonat, testautonatservice, testautorelay, testdcutr,
|
||||
testhpservice, testutility, testhelpers, testwildcardresolverservice, testperf
|
||||
|
||||
import kademlia/testencoding
|
||||
import kademlia/[testencoding, testroutingtable]
|
||||
|
||||
Reference in New Issue
Block a user