mirror of
https://github.com/vacp2p/mix.git
synced 2026-01-09 23:08:09 -05:00
Update to a GC-safe tag manager
This commit is contained in:
@@ -1,58 +1 @@
|
||||
import os, std/asyncdispatch, strutils
|
||||
import libp2p
|
||||
import libp2p/[switch,
|
||||
stream/connection,
|
||||
protocols/protocol,
|
||||
crypto/crypto,
|
||||
peerinfo,
|
||||
multiaddress,
|
||||
builders,
|
||||
peerid]
|
||||
import config, curve25519, network_manager, serialization, sphinx
|
||||
|
||||
const MixProtocolID = "/mix/1.0.0"
|
||||
|
||||
type
|
||||
MixProto = ref object of LPProtocol
|
||||
privateKey: FieldElement
|
||||
publicKey: FieldElement
|
||||
nm: NetworkManager
|
||||
|
||||
proc newMixProto(nm: NetworkManager, privateKey, publicKey: FieldElement): MixProto =
|
||||
result = MixProto(nm: nm, privateKey: privateKey, publicKey: publicKey)
|
||||
proc handle(conn: Connection, proto: string) {.async.} =
|
||||
while true:
|
||||
var receivedBytes = await conn.readLp(packetSize)
|
||||
|
||||
if receivedBytes.len == 0:
|
||||
break # No data, end of stream
|
||||
|
||||
while receivedBytes.len >= packetSize:
|
||||
let packet = receivedBytes[0..packetSize-1]
|
||||
receivedBytes = receivedBytes[packetSize..^1] # Remove the processed packet
|
||||
|
||||
# Process the packet
|
||||
let (nextHop, delay, processedPkt, status) = processSphinxPacket(packet, privateKey)
|
||||
|
||||
case status:
|
||||
of Success:
|
||||
if not ((nextHop == Hop()) and (delay == @[]) and (status == Success)):
|
||||
# Add delay
|
||||
let delayMillis = (delay[0] shl 8) or delay[1]
|
||||
sleep(int(delayMillis))
|
||||
|
||||
# Forward to next hop
|
||||
let multiAddr = cast[string](getHop(nextHop))
|
||||
let nextHopConn = await dialToNextHop(nm, nextHopAddr, MixProtocolID)
|
||||
await nextHopConn.writeLp(processedPkt)
|
||||
await nextHopConn.close()
|
||||
|
||||
of Duplicate:
|
||||
discard
|
||||
of InvalidMAC:
|
||||
discard
|
||||
of InvalidPoW:
|
||||
discard
|
||||
|
||||
# Close the current connection after processing
|
||||
await conn.close()
|
||||
@@ -139,7 +139,7 @@ proc wrapInSphinxPacket*( msg: Message, publicKeys: openArray[FieldElement], del
|
||||
let sphinxPacket = initSphinxPacket(initHeader(alpha_0, beta_0, gamma_0), delta_0)
|
||||
return serializeSphinxPacket(sphinxPacket)
|
||||
|
||||
proc processSphinxPacket*(serSphinxPacket: seq[byte], privateKey: FieldElement): (Hop, seq[byte], seq[byte], ProcessingStatus) =
|
||||
proc processSphinxPacket*(serSphinxPacket: seq[byte], privateKey: FieldElement, tm: var TagManager): (Hop, seq[byte], seq[byte], ProcessingStatus) =
|
||||
# Deserialize the Sphinx packet
|
||||
let sphinxPacket = deserializeSphinxPacket(serSphinxPacket)
|
||||
let (header, payload) = getSphinxPacket(sphinxPacket)
|
||||
@@ -150,7 +150,7 @@ proc processSphinxPacket*(serSphinxPacket: seq[byte], privateKey: FieldElement):
|
||||
let sBytes = fieldElementToBytes(s)
|
||||
|
||||
# Check if the tag has been seen
|
||||
if isTagSeen(s):
|
||||
if isTagSeen(tm, s):
|
||||
# If the tag is in the seen list, discard the message
|
||||
return (Hop(), @[], @[], Duplicate)
|
||||
|
||||
@@ -162,7 +162,7 @@ proc processSphinxPacket*(serSphinxPacket: seq[byte], privateKey: FieldElement):
|
||||
return (Hop(), @[], @[], InvalidMAC)
|
||||
|
||||
# Store the tag as seen
|
||||
addTag(s)
|
||||
addTag(tm, s)
|
||||
|
||||
# Derive AES key and IV
|
||||
let beta_aes_key = kdf(deriveKeyMaterial("aes_key", sBytes))
|
||||
|
||||
@@ -1,20 +1,27 @@
|
||||
import tables, curve25519
|
||||
import tables, curve25519, locks
|
||||
|
||||
# Define a global variable for the tag manager
|
||||
var seenTags*: Table[FieldElement, bool]
|
||||
type
|
||||
TagManager* = ref object
|
||||
lock: Lock
|
||||
seenTags: Table[FieldElement, bool]
|
||||
|
||||
# Initialize the tag manager
|
||||
proc initTagManager*() =
|
||||
seenTags = initTable[FieldElement, bool]()
|
||||
proc initTagManager*(): TagManager =
|
||||
new(result)
|
||||
result.seenTags = initTable[FieldElement, bool]()
|
||||
initLock(result.lock)
|
||||
|
||||
# Add a tag to the seen list
|
||||
proc addTag*(tag: FieldElement) =
|
||||
seenTags[tag] = true
|
||||
proc addTag*(tm: TagManager, tag: FieldElement) {.gcsafe.} =
|
||||
withLock tm.lock:
|
||||
tm.seenTags[tag] = true
|
||||
|
||||
# Check if a tag has been seen
|
||||
proc isTagSeen*(tag: FieldElement): bool =
|
||||
seenTags.contains(tag)
|
||||
proc isTagSeen*(tm: TagManager, tag: FieldElement): bool {.gcsafe.} =
|
||||
withLock tm.lock:
|
||||
result = tm.seenTags.contains(tag)
|
||||
|
||||
# Remove a tag from the seen list
|
||||
proc removeTag*(tag: FieldElement) =
|
||||
seenTags.del(tag)
|
||||
proc removeTag*(tm: TagManager, tag: FieldElement) {.gcsafe.} =
|
||||
withLock tm.lock:
|
||||
tm.seenTags.del(tag)
|
||||
|
||||
proc clearTags*(tm: TagManager) {.gcsafe.} =
|
||||
withLock tm.lock:
|
||||
tm.seenTags.clear()
|
||||
@@ -52,26 +52,30 @@ proc createDummyData(): (Message, seq[FieldElement], seq[FieldElement], seq[seq[
|
||||
|
||||
# Unit tests for sphinx.nim
|
||||
suite "Sphinx Tests":
|
||||
var tm: TagManager
|
||||
|
||||
setup:
|
||||
tm = initTagManager()
|
||||
|
||||
teardown:
|
||||
clearTags(tm)
|
||||
|
||||
test "sphinx_wrap_and_process":
|
||||
# Initialize tag manager
|
||||
initTagManager()
|
||||
|
||||
let (message, privateKeys, publicKeys, delay, hops) = createDummyData()
|
||||
let packet = wrapInSphinxPacket(message, publicKeys, delay, hops)
|
||||
assert packet.len == packetSize, "Packet size be exactly " & $packetSize & " bytes"
|
||||
|
||||
let (address1, delay1, processedPacket1, status1) = processSphinxPacket(packet, privateKeys[0])
|
||||
let (address1, delay1, processedPacket1, status1) = processSphinxPacket(packet, privateKeys[0], tm)
|
||||
assert status1 == Success, "Processing status should be Success"
|
||||
assert processedPacket1.len == packetSize, "Packet size be exactly " & $packetSize & " bytes"
|
||||
assert not ifExit(address1, delay1, processedPacket1, status1), "Packet processing failed"
|
||||
|
||||
let (address2, delay2, processedPacket2, status2) = processSphinxPacket(processedPacket1, privateKeys[1])
|
||||
let (address2, delay2, processedPacket2, status2) = processSphinxPacket(processedPacket1, privateKeys[1], tm)
|
||||
assert status2 == Success, "Processing status should be Success"
|
||||
assert processedPacket2.len == packetSize, "Packet size be exactly " & $packetSize & " bytes"
|
||||
assert not ifExit(address2, delay2, processedPacket2, status2), "Packet processing failed"
|
||||
|
||||
let (address3, delay3, processedPacket3, status3) = processSphinxPacket(processedPacket2, privateKeys[2])
|
||||
let (address3, delay3, processedPacket3, status3) = processSphinxPacket(processedPacket2, privateKeys[2], tm)
|
||||
assert status3 == Success, "Processing status should be Success"
|
||||
assert ifExit(address3, delay3, processedPacket3, status3), "Packet processing failed"
|
||||
|
||||
@@ -91,7 +95,7 @@ suite "Sphinx Tests":
|
||||
# Corrupt the MAC for testing
|
||||
var tamperedPacket = packet
|
||||
tamperedPacket[0] = packet[0] xor 0x01
|
||||
let (_, _, _, status) = processSphinxPacket(tamperedPacket, privateKeys[0])
|
||||
let (_, _, _, status) = processSphinxPacket(tamperedPacket, privateKeys[0], tm)
|
||||
assert status == InvalidMAC, "Processing status should be InvalidMAC"
|
||||
|
||||
test "sphinx_process_duplicate_tag":
|
||||
@@ -100,9 +104,9 @@ suite "Sphinx Tests":
|
||||
assert packet.len == packetSize, "Packet size be exactly " & $packetSize & " bytes"
|
||||
|
||||
# Process the packet twice to test duplicate tag handling
|
||||
let (_, _, _, status1) = processSphinxPacket(packet, privateKeys[0])
|
||||
let (_, _, _, status1) = processSphinxPacket(packet, privateKeys[0], tm)
|
||||
assert status1 == Success, "Processing status should be Success"
|
||||
let (_, _, _, status2) = processSphinxPacket(packet, privateKeys[0])
|
||||
let (_, _, _, status2) = processSphinxPacket(packet, privateKeys[0], tm)
|
||||
assert status2 == Duplicate, "Processing status should be Duplicate"
|
||||
|
||||
test "sphinx_wrap_and_process_message_sizes":
|
||||
@@ -117,17 +121,17 @@ suite "Sphinx Tests":
|
||||
let packet = wrapInSphinxPacket(initMessage(paddedMessage), publicKeys, delay, hops)
|
||||
assert packet.len == packetSize, "Packet size be exactly " & $packetSize & " bytes for message size " & $messageSize
|
||||
|
||||
let (address1, delay1, processedPacket1, status1) = processSphinxPacket(packet, privateKeys[0])
|
||||
let (address1, delay1, processedPacket1, status1) = processSphinxPacket(packet, privateKeys[0], tm)
|
||||
assert status1 == Success, "Processing status should be Success"
|
||||
assert processedPacket1.len == packetSize, "Packet size be exactly " & $packetSize & " bytes"
|
||||
assert not ifExit(address1, delay1, processedPacket1, status1), "Packet processing failed"
|
||||
|
||||
let (address2, delay2, processedPacket2, status2) = processSphinxPacket(processedPacket1, privateKeys[1])
|
||||
let (address2, delay2, processedPacket2, status2) = processSphinxPacket(processedPacket1, privateKeys[1], tm)
|
||||
assert status2 == Success, "Processing status should be Success"
|
||||
assert processedPacket2.len == packetSize, "Packet size be exactly " & $packetSize & " bytes"
|
||||
assert not ifExit(address2, delay2, processedPacket2, status2), "Packet processing failed"
|
||||
|
||||
let (address3, delay3, processedPacket3, status3) = processSphinxPacket(processedPacket2, privateKeys[2])
|
||||
let (address3, delay3, processedPacket3, status3) = processSphinxPacket(processedPacket2, privateKeys[2], tm)
|
||||
assert status3 == Success, "Processing status should be Success"
|
||||
assert ifExit(address3, delay3, processedPacket3, status3), "Packet processing failed"
|
||||
|
||||
|
||||
@@ -1,42 +1,44 @@
|
||||
import unittest, ../src/tag_manager, ../src/curve25519, tables
|
||||
import unittest, ../src/tag_manager, ../src/curve25519
|
||||
|
||||
suite "tag_manager_tests":
|
||||
# Setup to initialize the tag manager before running tests
|
||||
initTagManager()
|
||||
var tm: TagManager
|
||||
|
||||
test "add_and_check_tag":
|
||||
let tag = generateRandomFieldElement()
|
||||
addTag(tag)
|
||||
check isTagSeen(tag)
|
||||
let nonexistentTag = generateRandomFieldElement()
|
||||
check not isTagSeen(nonexistentTag)
|
||||
setup:
|
||||
tm = initTagManager()
|
||||
|
||||
test "remove_tag":
|
||||
let tag = generateRandomFieldElement()
|
||||
addTag(tag)
|
||||
check isTagSeen(tag)
|
||||
removeTag(tag)
|
||||
check not isTagSeen(tag)
|
||||
teardown:
|
||||
clearTags(tm)
|
||||
|
||||
test "check_tag_presence":
|
||||
let tag = generateRandomFieldElement()
|
||||
check not isTagSeen(tag)
|
||||
addTag(tag)
|
||||
check isTagSeen(tag)
|
||||
removeTag(tag)
|
||||
check not isTagSeen(tag)
|
||||
test "add_and_check_tag":
|
||||
let tag = generateRandomFieldElement()
|
||||
addTag(tm, tag)
|
||||
check isTagSeen(tm, tag)
|
||||
let nonexistentTag = generateRandomFieldElement()
|
||||
check not isTagSeen(tm, nonexistentTag)
|
||||
|
||||
test "handle_multiple_tags":
|
||||
let tag1 = generateRandomFieldElement()
|
||||
let tag2 = generateRandomFieldElement()
|
||||
addTag(tag1)
|
||||
addTag(tag2)
|
||||
check isTagSeen(tag1)
|
||||
check isTagSeen(tag2)
|
||||
removeTag(tag1)
|
||||
removeTag(tag2)
|
||||
check not isTagSeen(tag1)
|
||||
check not isTagSeen(tag2)
|
||||
test "remove_tag":
|
||||
let tag = generateRandomFieldElement()
|
||||
addTag(tm, tag)
|
||||
check isTagSeen(tm, tag)
|
||||
removeTag(tm, tag)
|
||||
check not isTagSeen(tm, tag)
|
||||
|
||||
# Teardown to clean up after running tests
|
||||
clear(seenTags)
|
||||
test "check_tag_presence":
|
||||
let tag = generateRandomFieldElement()
|
||||
check not isTagSeen(tm, tag)
|
||||
addTag(tm, tag)
|
||||
check isTagSeen(tm, tag)
|
||||
removeTag(tm, tag)
|
||||
check not isTagSeen(tm, tag)
|
||||
|
||||
test "handle_multiple_tags":
|
||||
let tag1 = generateRandomFieldElement()
|
||||
let tag2 = generateRandomFieldElement()
|
||||
addTag(tm, tag1)
|
||||
addTag(tm, tag2)
|
||||
check isTagSeen(tm, tag1)
|
||||
check isTagSeen(tm, tag2)
|
||||
removeTag(tm, tag1)
|
||||
removeTag(tm, tag2)
|
||||
check not isTagSeen(tm, tag1)
|
||||
check not isTagSeen(tm, tag2)
|
||||
Reference in New Issue
Block a user