Update to a GC-safe tag manager

This commit is contained in:
akshaya
2024-09-20 14:39:16 -04:00
parent 07517d8cd6
commit 88f7a1ad6d
5 changed files with 78 additions and 122 deletions

View File

@@ -1,58 +1 @@
import os, std/asyncdispatch, strutils
import libp2p
import libp2p/[switch,
stream/connection,
protocols/protocol,
crypto/crypto,
peerinfo,
multiaddress,
builders,
peerid]
import config, curve25519, network_manager, serialization, sphinx
const MixProtocolID = "/mix/1.0.0"
type
MixProto = ref object of LPProtocol
privateKey: FieldElement
publicKey: FieldElement
nm: NetworkManager
proc newMixProto(nm: NetworkManager, privateKey, publicKey: FieldElement): MixProto =
result = MixProto(nm: nm, privateKey: privateKey, publicKey: publicKey)
proc handle(conn: Connection, proto: string) {.async.} =
while true:
var receivedBytes = await conn.readLp(packetSize)
if receivedBytes.len == 0:
break # No data, end of stream
while receivedBytes.len >= packetSize:
let packet = receivedBytes[0..packetSize-1]
receivedBytes = receivedBytes[packetSize..^1] # Remove the processed packet
# Process the packet
let (nextHop, delay, processedPkt, status) = processSphinxPacket(packet, privateKey)
case status:
of Success:
if not ((nextHop == Hop()) and (delay == @[]) and (status == Success)):
# Add delay
let delayMillis = (delay[0] shl 8) or delay[1]
sleep(int(delayMillis))
# Forward to next hop
let multiAddr = cast[string](getHop(nextHop))
let nextHopConn = await dialToNextHop(nm, nextHopAddr, MixProtocolID)
await nextHopConn.writeLp(processedPkt)
await nextHopConn.close()
of Duplicate:
discard
of InvalidMAC:
discard
of InvalidPoW:
discard
# Close the current connection after processing
await conn.close()

View File

@@ -139,7 +139,7 @@ proc wrapInSphinxPacket*( msg: Message, publicKeys: openArray[FieldElement], del
let sphinxPacket = initSphinxPacket(initHeader(alpha_0, beta_0, gamma_0), delta_0)
return serializeSphinxPacket(sphinxPacket)
proc processSphinxPacket*(serSphinxPacket: seq[byte], privateKey: FieldElement): (Hop, seq[byte], seq[byte], ProcessingStatus) =
proc processSphinxPacket*(serSphinxPacket: seq[byte], privateKey: FieldElement, tm: var TagManager): (Hop, seq[byte], seq[byte], ProcessingStatus) =
# Deserialize the Sphinx packet
let sphinxPacket = deserializeSphinxPacket(serSphinxPacket)
let (header, payload) = getSphinxPacket(sphinxPacket)
@@ -150,7 +150,7 @@ proc processSphinxPacket*(serSphinxPacket: seq[byte], privateKey: FieldElement):
let sBytes = fieldElementToBytes(s)
# Check if the tag has been seen
if isTagSeen(s):
if isTagSeen(tm, s):
# If the tag is in the seen list, discard the message
return (Hop(), @[], @[], Duplicate)
@@ -162,7 +162,7 @@ proc processSphinxPacket*(serSphinxPacket: seq[byte], privateKey: FieldElement):
return (Hop(), @[], @[], InvalidMAC)
# Store the tag as seen
addTag(s)
addTag(tm, s)
# Derive AES key and IV
let beta_aes_key = kdf(deriveKeyMaterial("aes_key", sBytes))

View File

@@ -1,20 +1,27 @@
import tables, curve25519
import tables, curve25519, locks
# Define a global variable for the tag manager
var seenTags*: Table[FieldElement, bool]
type
TagManager* = ref object
lock: Lock
seenTags: Table[FieldElement, bool]
# Initialize the tag manager
proc initTagManager*() =
seenTags = initTable[FieldElement, bool]()
proc initTagManager*(): TagManager =
new(result)
result.seenTags = initTable[FieldElement, bool]()
initLock(result.lock)
# Add a tag to the seen list
proc addTag*(tag: FieldElement) =
seenTags[tag] = true
proc addTag*(tm: TagManager, tag: FieldElement) {.gcsafe.} =
withLock tm.lock:
tm.seenTags[tag] = true
# Check if a tag has been seen
proc isTagSeen*(tag: FieldElement): bool =
seenTags.contains(tag)
proc isTagSeen*(tm: TagManager, tag: FieldElement): bool {.gcsafe.} =
withLock tm.lock:
result = tm.seenTags.contains(tag)
# Remove a tag from the seen list
proc removeTag*(tag: FieldElement) =
seenTags.del(tag)
proc removeTag*(tm: TagManager, tag: FieldElement) {.gcsafe.} =
withLock tm.lock:
tm.seenTags.del(tag)
proc clearTags*(tm: TagManager) {.gcsafe.} =
withLock tm.lock:
tm.seenTags.clear()

View File

@@ -52,26 +52,30 @@ proc createDummyData(): (Message, seq[FieldElement], seq[FieldElement], seq[seq[
# Unit tests for sphinx.nim
suite "Sphinx Tests":
var tm: TagManager
setup:
tm = initTagManager()
teardown:
clearTags(tm)
test "sphinx_wrap_and_process":
# Initialize tag manager
initTagManager()
let (message, privateKeys, publicKeys, delay, hops) = createDummyData()
let packet = wrapInSphinxPacket(message, publicKeys, delay, hops)
assert packet.len == packetSize, "Packet size be exactly " & $packetSize & " bytes"
let (address1, delay1, processedPacket1, status1) = processSphinxPacket(packet, privateKeys[0])
let (address1, delay1, processedPacket1, status1) = processSphinxPacket(packet, privateKeys[0], tm)
assert status1 == Success, "Processing status should be Success"
assert processedPacket1.len == packetSize, "Packet size be exactly " & $packetSize & " bytes"
assert not ifExit(address1, delay1, processedPacket1, status1), "Packet processing failed"
let (address2, delay2, processedPacket2, status2) = processSphinxPacket(processedPacket1, privateKeys[1])
let (address2, delay2, processedPacket2, status2) = processSphinxPacket(processedPacket1, privateKeys[1], tm)
assert status2 == Success, "Processing status should be Success"
assert processedPacket2.len == packetSize, "Packet size be exactly " & $packetSize & " bytes"
assert not ifExit(address2, delay2, processedPacket2, status2), "Packet processing failed"
let (address3, delay3, processedPacket3, status3) = processSphinxPacket(processedPacket2, privateKeys[2])
let (address3, delay3, processedPacket3, status3) = processSphinxPacket(processedPacket2, privateKeys[2], tm)
assert status3 == Success, "Processing status should be Success"
assert ifExit(address3, delay3, processedPacket3, status3), "Packet processing failed"
@@ -91,7 +95,7 @@ suite "Sphinx Tests":
# Corrupt the MAC for testing
var tamperedPacket = packet
tamperedPacket[0] = packet[0] xor 0x01
let (_, _, _, status) = processSphinxPacket(tamperedPacket, privateKeys[0])
let (_, _, _, status) = processSphinxPacket(tamperedPacket, privateKeys[0], tm)
assert status == InvalidMAC, "Processing status should be InvalidMAC"
test "sphinx_process_duplicate_tag":
@@ -100,9 +104,9 @@ suite "Sphinx Tests":
assert packet.len == packetSize, "Packet size be exactly " & $packetSize & " bytes"
# Process the packet twice to test duplicate tag handling
let (_, _, _, status1) = processSphinxPacket(packet, privateKeys[0])
let (_, _, _, status1) = processSphinxPacket(packet, privateKeys[0], tm)
assert status1 == Success, "Processing status should be Success"
let (_, _, _, status2) = processSphinxPacket(packet, privateKeys[0])
let (_, _, _, status2) = processSphinxPacket(packet, privateKeys[0], tm)
assert status2 == Duplicate, "Processing status should be Duplicate"
test "sphinx_wrap_and_process_message_sizes":
@@ -117,17 +121,17 @@ suite "Sphinx Tests":
let packet = wrapInSphinxPacket(initMessage(paddedMessage), publicKeys, delay, hops)
assert packet.len == packetSize, "Packet size be exactly " & $packetSize & " bytes for message size " & $messageSize
let (address1, delay1, processedPacket1, status1) = processSphinxPacket(packet, privateKeys[0])
let (address1, delay1, processedPacket1, status1) = processSphinxPacket(packet, privateKeys[0], tm)
assert status1 == Success, "Processing status should be Success"
assert processedPacket1.len == packetSize, "Packet size be exactly " & $packetSize & " bytes"
assert not ifExit(address1, delay1, processedPacket1, status1), "Packet processing failed"
let (address2, delay2, processedPacket2, status2) = processSphinxPacket(processedPacket1, privateKeys[1])
let (address2, delay2, processedPacket2, status2) = processSphinxPacket(processedPacket1, privateKeys[1], tm)
assert status2 == Success, "Processing status should be Success"
assert processedPacket2.len == packetSize, "Packet size be exactly " & $packetSize & " bytes"
assert not ifExit(address2, delay2, processedPacket2, status2), "Packet processing failed"
let (address3, delay3, processedPacket3, status3) = processSphinxPacket(processedPacket2, privateKeys[2])
let (address3, delay3, processedPacket3, status3) = processSphinxPacket(processedPacket2, privateKeys[2], tm)
assert status3 == Success, "Processing status should be Success"
assert ifExit(address3, delay3, processedPacket3, status3), "Packet processing failed"

View File

@@ -1,42 +1,44 @@
import unittest, ../src/tag_manager, ../src/curve25519, tables
import unittest, ../src/tag_manager, ../src/curve25519
suite "tag_manager_tests":
# Setup to initialize the tag manager before running tests
initTagManager()
var tm: TagManager
test "add_and_check_tag":
let tag = generateRandomFieldElement()
addTag(tag)
check isTagSeen(tag)
let nonexistentTag = generateRandomFieldElement()
check not isTagSeen(nonexistentTag)
setup:
tm = initTagManager()
test "remove_tag":
let tag = generateRandomFieldElement()
addTag(tag)
check isTagSeen(tag)
removeTag(tag)
check not isTagSeen(tag)
teardown:
clearTags(tm)
test "check_tag_presence":
let tag = generateRandomFieldElement()
check not isTagSeen(tag)
addTag(tag)
check isTagSeen(tag)
removeTag(tag)
check not isTagSeen(tag)
test "add_and_check_tag":
let tag = generateRandomFieldElement()
addTag(tm, tag)
check isTagSeen(tm, tag)
let nonexistentTag = generateRandomFieldElement()
check not isTagSeen(tm, nonexistentTag)
test "handle_multiple_tags":
let tag1 = generateRandomFieldElement()
let tag2 = generateRandomFieldElement()
addTag(tag1)
addTag(tag2)
check isTagSeen(tag1)
check isTagSeen(tag2)
removeTag(tag1)
removeTag(tag2)
check not isTagSeen(tag1)
check not isTagSeen(tag2)
test "remove_tag":
let tag = generateRandomFieldElement()
addTag(tm, tag)
check isTagSeen(tm, tag)
removeTag(tm, tag)
check not isTagSeen(tm, tag)
# Teardown to clean up after running tests
clear(seenTags)
test "check_tag_presence":
let tag = generateRandomFieldElement()
check not isTagSeen(tm, tag)
addTag(tm, tag)
check isTagSeen(tm, tag)
removeTag(tm, tag)
check not isTagSeen(tm, tag)
test "handle_multiple_tags":
let tag1 = generateRandomFieldElement()
let tag2 = generateRandomFieldElement()
addTag(tm, tag1)
addTag(tm, tag2)
check isTagSeen(tm, tag1)
check isTagSeen(tm, tag2)
removeTag(tm, tag1)
removeTag(tm, tag2)
check not isTagSeen(tm, tag1)
check not isTagSeen(tm, tag2)