feat(mix): SURBs and fragmentation (#1700)

This commit is contained in:
richΛrd
2025-09-19 15:56:26 -04:00
committed by GitHub
parent 37bae0986c
commit 70b7d61436
6 changed files with 481 additions and 53 deletions

View File

@@ -0,0 +1,95 @@
import ./[serialization, seqno_generator]
import results, stew/endians2
import ../../peerid
const PaddingLengthSize* = 2
const SeqNoSize* = 4
const DataSize* = MessageSize - PaddingLengthSize - SeqNoSize
# Unpadding and reassembling messages will be handled by the top-level applications.
# Although padding and splitting messages could also be managed at that level, we
# implement it here to clarify the sender's logic.
# This is crucial as the sender is responsible for wrapping messages in Sphinx packets.
type MessageChunk* = object
paddingLength: uint16
data: seq[byte]
seqNo: uint32
proc init*(
T: typedesc[MessageChunk], paddingLength: uint16, data: seq[byte], seqNo: uint32
): T =
T(paddingLength: paddingLength, data: data, seqNo: seqNo)
proc get*(msgChunk: MessageChunk): (uint16, seq[byte], uint32) =
(msgChunk.paddingLength, msgChunk.data, msgChunk.seqNo)
proc serialize*(msgChunk: MessageChunk): seq[byte] =
let
paddingBytes = msgChunk.paddingLength.toBytesBE()
seqNoBytes = msgChunk.seqNo.toBytesBE()
doAssert msgChunk.data.len == DataSize,
"Padded data must be exactly " & $DataSize & " bytes"
return @paddingBytes & msgChunk.data & @seqNoBytes
proc deserialize*(T: typedesc[MessageChunk], data: openArray[byte]): Result[T, string] =
if data.len != MessageSize:
return err("Data must be exactly " & $MessageSize & " bytes")
let
paddingLength = uint16.fromBytesBE(data[0 .. PaddingLengthSize - 1])
chunk = data[PaddingLengthSize .. (PaddingLengthSize + DataSize - 1)]
seqNo = uint32.fromBytesBE(data[PaddingLengthSize + DataSize ..^ 1])
ok(T(paddingLength: paddingLength, data: chunk, seqNo: seqNo))
proc ceilDiv*(a, b: int): int =
(a + b - 1) div b
proc addPadding*(messageBytes: seq[byte], seqNo: SeqNo): MessageChunk =
## Pads messages smaller than DataSize
let paddingLength = uint16(DataSize - messageBytes.len)
let paddedData =
if paddingLength > 0:
let paddingBytes = newSeq[byte](paddingLength)
paddingBytes & messageBytes
else:
messageBytes
MessageChunk(paddingLength: paddingLength, data: paddedData, seqNo: seqNo)
proc addPadding*(messageBytes: seq[byte], peerId: PeerId): MessageChunk =
## Pads messages smaller than DataSize
var seqNoGen = SeqNo.init(peerId)
seqNoGen.generate(messageBytes)
messageBytes.addPadding(seqNoGen)
proc removePadding*(msgChunk: MessageChunk): Result[seq[byte], string] =
let msgLength = len(msgChunk.data) - int(msgChunk.paddingLength)
if msgLength < 0:
return err("Invalid padding length")
ok(msgChunk.data[msgChunk.paddingLength ..^ 1])
proc padAndChunkMessage*(messageBytes: seq[byte], peerId: PeerId): seq[MessageChunk] =
var seqNoGen = SeqNo.init(peerId)
seqNoGen.generate(messageBytes)
var chunks: seq[MessageChunk] = @[]
# Split to chunks
let totalChunks = max(1, ceilDiv(messageBytes.len, DataSize))
# Ensure at least one chunk is generated
for i in 0 ..< totalChunks:
let
startIdx = i * DataSize
endIdx = min(startIdx + DataSize, messageBytes.len)
chunkData = messageBytes[startIdx .. endIdx - 1]
msgChunk = chunkData.addPadding(seqNoGen)
chunks.add(msgChunk)
seqNoGen.inc()
return chunks

View File

@@ -139,6 +139,23 @@ proc serialize*(info: RoutingInfo): seq[byte] =
return addrBytes & info.Delay & info.Gamma & info.Beta
proc readBytes(
data: openArray[byte], offset: var int, readSize: Opt[int] = Opt.none(int)
): Result[seq[byte], string] =
if data.len < offset:
return err("not enough data")
readSize.withValue(size):
if data.len < offset + size:
return err("not enough data")
let slice = data[offset ..< offset + size]
offset += size
return ok(slice)
let slice = data[offset .. ^1]
offset = data.len
return ok(slice)
proc deserialize*(T: typedesc[RoutingInfo], data: openArray[byte]): Result[T, string] =
if len(data) != BetaSize + ((t + 1) * k):
return err("Data must be exactly " & $(BetaSize + ((t + 1) * k)) & " bytes")
@@ -146,13 +163,13 @@ proc deserialize*(T: typedesc[RoutingInfo], data: openArray[byte]): Result[T, st
let hop = Hop.deserialize(data[0 .. AddrSize - 1]).valueOr:
return err("Deserialize hop error: " & error)
var offset: int = AddrSize
return ok(
RoutingInfo(
Addr: hop,
Delay: data[AddrSize .. (AddrSize + DelaySize - 1)],
Gamma: data[(AddrSize + DelaySize) .. (AddrSize + DelaySize + GammaSize - 1)],
Beta:
data[(AddrSize + DelaySize + GammaSize) .. (((r * (t + 1)) + t + 2) * k) - 1],
Delay: ?data.readBytes(offset, Opt.some(DelaySize)),
Gamma: ?data.readBytes(offset, Opt.some(GammaSize)),
Beta: ?data.readBytes(offset, Opt.some(BetaSize)),
)
)
@@ -183,7 +200,7 @@ type
Key* = seq[byte]
I* = array[SurbIdLen, byte]
SURBIdentifier* = array[SurbIdLen, byte]
SURB* = object
hop*: Hop
@@ -201,23 +218,6 @@ proc serializeMessageWithSURBs*(
surbs.mapIt(it.hop.serialize() & it.header.serialize() & it.key).concat()
ok(byte(surbs.len) & surbBytes & msg)
proc readBytes(
data: seq[byte], offset: var int, readSize: Opt[int] = Opt.none(int)
): Result[seq[byte], string] =
if data.len < offset:
return err("not enough data")
readSize.withValue(size):
if data.len < offset + size:
return err("not enough data")
let slice = data[offset ..< offset + size]
offset += size
return ok(slice)
let slice = data[offset .. ^1]
offset = data.len
return ok(slice)
proc extractSURBs*(msg: seq[byte]): Result[(seq[SURB], seq[byte]), string] =
var offset = 0
let surbsLenBytes = ?readBytes(msg, offset, Opt.some(1))

View File

@@ -1,11 +1,16 @@
import results, sequtils, stew/endians2
import ./[crypto, curve25519, serialization, tag_manager]
import ../../crypto/crypto
import ../../utils/sequninit
const PaddingLength = (((t + 1) * (r - L)) + 1) * k
type ProcessingStatus* = enum
Exit # Packet processed successfully at exit
Intermediate # Packet processed successfully at intermediate node
Duplicate # Packet was discarded due to duplicate tag
InvalidMAC # Packet was discarded due to MAC verification failure
Exit
Intermediate
Reply
Duplicate
InvalidMAC
proc computeAlpha(
publicKeys: openArray[FieldElement]
@@ -79,16 +84,12 @@ proc computeFillerStrings(s: seq[seq[byte]]): Result[seq[byte], string] =
return ok(filler)
const paddingLength = (((t + 1) * (r - L)) + 1) * k
# Function to compute:
proc computeBetaGamma(
s: seq[seq[byte]],
hop: openArray[Hop],
delay: openArray[seq[byte]],
destHop: Hop,
id: I,
id: SURBIdentifier,
): Result[tuple[beta: seq[byte], gamma: seq[byte]], string] =
## Calculates the following elements:
## - Beta: The nested encrypted routing information. It encodes the next hop address, the forwarding delay, integrity check Gamma for the next hop, and the Beta for subsequent hops.
@@ -112,7 +113,7 @@ proc computeBetaGamma(
# Compute Beta and Gamma
if i == sLen - 1:
let destBytes = destHop.serialize()
let destPadding = destBytes & delay[i] & @id & newSeq[byte](paddingLength)
let destPadding = destBytes & delay[i] & @id & newSeq[byte](PaddingLength)
let aes = aes_ctr(beta_aes_key, beta_iv, destPadding)
@@ -149,6 +150,70 @@ proc computeDelta(s: seq[seq[byte]], msg: Message): Result[seq[byte], string] =
return ok(delta)
proc createSURB*(
publicKeys: openArray[FieldElement],
delay: openArray[seq[byte]],
hops: openArray[Hop],
id: SURBIdentifier,
rng: ref HmacDrbgContext = newRng(),
): Result[SURB, string] =
if id == default(SURBIdentifier):
return err("id should be initialized")
# Compute alpha and shared secrets
let (alpha_0, s) = computeAlpha(publicKeys).valueOr:
return err("Error in alpha generation: " & error)
# Compute beta and gamma
let (beta_0, gamma_0) = computeBetaGamma(s, hops, delay, Hop(), id).valueOr:
return err("Error in beta and gamma generation: " & error)
# Generate key
var key = newSeqUninit[byte](k)
rng[].generate(key)
return ok(
SURB(
hop: hops[0],
header: Header.init(alpha_0, beta_0, gamma_0),
secret: Opt.some(s),
key: key,
)
)
proc useSURB*(surb: SURB, msg: Message): SphinxPacket =
# Derive AES key and IV
let
delta_aes_key = deriveKeyMaterial("delta_aes_key", surb.key).kdf()
delta_iv = deriveKeyMaterial("delta_iv", surb.key).kdf()
# Compute Delta
let serializedMsg = msg.serialize()
let delta = aes_ctr(delta_aes_key, delta_iv, serializedMsg)
return SphinxPacket.init(surb.header, delta)
proc processReply*(
key: seq[byte], s: seq[seq[byte]], delta_prime: seq[byte]
): Result[seq[byte], string] =
var delta = delta_prime[0 ..^ 1]
var key_prime = key
for i in 0 .. s.len:
if i != 0:
key_prime = s[i - 1]
let
delta_aes_key = deriveKeyMaterial("delta_aes_key", key_prime).kdf()
delta_iv = deriveKeyMaterial("delta_iv", key_prime).kdf()
delta = aes_ctr(delta_aes_key, delta_iv, delta)
let deserializeMsg = Message.deserialize(delta).valueOr:
return err("Message deserialization error: " & error)
return ok(deserializeMsg)
proc wrapInSphinxPacket*(
msg: Message,
publicKeys: openArray[FieldElement],
@@ -161,7 +226,9 @@ proc wrapInSphinxPacket*(
return err("Error in alpha generation: " & error)
# Compute beta and gamma
let (beta_0, gamma_0) = computeBetaGamma(s, hop, delay, destHop, default(I)).valueOr:
let (beta_0, gamma_0) = computeBetaGamma(
s, hop, delay, destHop, default(SURBIdentifier)
).valueOr:
return err("Error in beta and gamma generation: " & error)
# Compute delta
@@ -184,9 +251,27 @@ type ProcessedSphinxPacket* = object
nextHop*: Hop
delayMs*: int
serializedSphinxPacket*: seq[byte]
of ProcessingStatus.Reply:
id*: SURBIdentifier
delta_prime*: seq[byte]
else:
discard
proc isZeros(data: seq[byte], startIdx: int, endIdx: int): bool =
doAssert 0 <= startIdx and endIdx < data.len and startIdx <= endIdx
for i in startIdx .. endIdx:
if data[i] != 0:
return false
return true
template extractSurbId(data: seq[byte]): SURBIdentifier =
const startIndex = t * k
const endIndex = startIndex + SurbIdLen - 1
doAssert data.len > startIndex and endIndex < data.len
var id: SURBIdentifier
copyMem(addr id[0], addr data[startIndex], SurbIdLen)
id
proc processSphinxPacket*(
sphinxPacket: SphinxPacket, privateKey: FieldElement, tm: var TagManager
): Result[ProcessedSphinxPacket, string] =
@@ -228,19 +313,16 @@ proc processSphinxPacket*(
let delta_prime = aes_ctr(delta_aes_key, delta_iv, payload)
# Compute B
var zeroPadding = newSeq[byte]((t + 1) * k)
let zeroPadding = newSeq[byte]((t + 1) * k)
let B = aes_ctr(beta_aes_key, beta_iv, beta & zeroPadding)
# Check if B has the required prefix for the original message
zeroPadding = newSeq[byte](paddingLength)
if B[((t + 1) * k) .. ((t + 1) * k) + paddingLength - 1] == zeroPadding:
if B.isZeros((t + 1) * k, ((t + 1) * k) + PaddingLength - 1):
let hop = Hop.deserialize(B[0 .. AddrSize - 1]).valueOr:
return err(error)
if B[AddrSize .. ((t + 1) * k) - 1] == newSeq[byte](k + 2):
if delta_prime[0 .. (k - 1)] == newSeq[byte](k):
if B.isZeros(AddrSize, ((t + 1) * k) - 1):
if delta_prime.isZeros(0, k - 1):
let msg = Message.deserialize(delta_prime).valueOr:
return err("Message deserialization error: " & error)
return ok(
@@ -250,9 +332,12 @@ proc processSphinxPacket*(
)
else:
return err("delta_prime should be all zeros")
elif B[0 .. (t * k) - 1] == newSeq[byte](t * k):
# TODO: handle REPLY case
discard
elif B.isZeros(0, (t * k) - 1):
return ok(
ProcessedSphinxPacket(
status: Reply, id: B.extractSurbId(), delta_prime: delta_prime
)
)
else:
# Extract routing information from B
let routingInfo = RoutingInfo.deserialize(B).valueOr:

View File

@@ -0,0 +1,101 @@
{.used.}
import results, unittest
import ../../libp2p/peerid
import ../../libp2p/protocols/mix/[serialization, fragmentation]
suite "Fragmentation":
let peerId =
PeerId.init("16Uiu2HAmFkwLVsVh6gGPmSm9R3X4scJ5thVdKfWYeJsKeVrbcgVC").get()
test "serialize and deserialize message chunk":
let
message = newSeq[byte](DataSize)
chunks = padAndChunkMessage(message, peerId)
serialized = chunks[0].serialize()
deserialized =
MessageChunk.deserialize(serialized).expect("Deserialization error")
check chunks[0] == deserialized
test "pad and unpad small message":
let
message = cast[seq[byte]]("Hello, World!")
messageBytesLen = len(message)
paddedMsg = addPadding(message, peerId)
unpaddedMessage = removePadding(paddedMsg).expect("Unpad error")
let (paddingLength, data, _) = paddedMsg.get()
check:
paddingLength == uint16(DataSize - messageBytesLen)
data.len == DataSize
unpaddedMessage.len == messageBytesLen
test "pad and chunk large message":
let
message = newSeq[byte](MessageSize * 2 + (MessageSize - 1))
messageBytesLen = len(message)
chunks = padAndChunkMessage(message, peerId)
totalChunks = max(1, ceilDiv(messageBytesLen, DataSize))
check chunks.len == totalChunks
for i in 0 ..< totalChunks:
let (paddingLength, data, _) = chunks[i].get()
if i != totalChunks - 1:
check paddingLength == 0
else:
let chunkSize = messageBytesLen mod DataSize
check paddingLength == uint16(DataSize - chunkSize)
check data.len == DataSize
test "chunk sequence numbers are consecutive":
let
message = newSeq[byte](MessageSize * 3)
messageBytesLen = len(message)
chunks = padAndChunkMessage(message, peerId)
totalChunks = max(1, ceilDiv(messageBytesLen, DataSize))
check chunks.len == totalChunks
let (_, _, firstSeqNo) = chunks[0].get()
for i in 1 ..< totalChunks:
let (_, _, seqNo) = chunks[i].get()
check seqNo == firstSeqNo + uint32(i)
test "chunk data reconstructs original message":
let
message = cast[seq[byte]]("This is a test message that will be split into multiple chunks.")
chunks = padAndChunkMessage(message, peerId)
var reconstructed: seq[byte]
for chunk in chunks:
let (paddingLength, data, _) = chunk.get()
reconstructed.add(data[paddingLength.int ..^ 1])
check reconstructed == message
test "empty message handling":
let
message = cast[seq[byte]]("")
chunks = padAndChunkMessage(message, peerId)
check chunks.len == 1
let (paddingLength, _, _) = chunks[0].get()
check paddingLength == uint16(DataSize)
test "message size equal to chunk size":
let
message = newSeq[byte](DataSize)
chunks = padAndChunkMessage(message, peerId)
check chunks.len == 1
let (paddingLength, _, _) = chunks[0].get()
check paddingLength == 0

View File

@@ -1,12 +1,12 @@
{.used.}
import random, results, unittest
import random, results, unittest, chronicles
import ../../libp2p/crypto/crypto
import ../../libp2p/protocols/mix/[curve25519, serialization, sphinx, tag_manager]
import bearssl/rand
# Helper function to pad/truncate message
proc padMessage(message: openArray[byte], size: int): seq[byte] =
proc addPadding(message: openArray[byte], size: int): seq[byte] =
if message.len >= size:
return message[0 .. size - 1] # Truncate if larger
else:
@@ -39,8 +39,8 @@ proc createDummyData(): (
dest = Hop.init(newSeq[byte](AddrSize))
return (message, privateKeys, publicKeys, delay, hops, dest)
proc randomI(): I =
newRng()[].generate(I)
template randomI(): SURBIdentifier =
newRng()[].generate(SURBIdentifier)
# Unit tests for sphinx.nim
suite "Sphinx Tests":
@@ -52,7 +52,7 @@ suite "Sphinx Tests":
teardown:
clearTags(tm)
test "sphinx_wrap_and_process":
test "sphinx wrap and process":
let (message, privateKeys, publicKeys, delay, hops, dest) = createDummyData()
let packetBytes = wrapInSphinxPacket(message, publicKeys, delay, hops, dest).expect(
@@ -94,7 +94,7 @@ suite "Sphinx Tests":
processedSP3.status == Exit
processedSP3.messageChunk == message
test "sphinx_wrap_empty_public_keys":
test "sphinx wrap empty public keys":
let (message, _, _, delay, _, dest) = createDummyData()
check wrapInSphinxPacket(message, @[], delay, @[], dest).isErr
@@ -118,7 +118,7 @@ suite "Sphinx Tests":
check invalidMacPkt.status == InvalidMAC
test "sphinx_process_duplicate_tag":
test "sphinx process duplicate tag":
let (message, privateKeys, publicKeys, delay, hops, dest) = createDummyData()
let packetBytes = wrapInSphinxPacket(message, publicKeys, delay, hops, dest).expect(
@@ -140,7 +140,7 @@ suite "Sphinx Tests":
check processedSP2.status == Duplicate
test "sphinx_wrap_and_process_message_sizes":
test "sphinx wrap and process message sizes":
let MessageSizes = @[32, 64, 128, 256, 512]
for size in MessageSizes:
let (_, privateKeys, publicKeys, delay, hops, dest) = createDummyData()
@@ -148,7 +148,7 @@ suite "Sphinx Tests":
randomize()
for i in 0 ..< size:
message[i] = byte(rand(256))
let paddedMessage = padMessage(message, MessageSize)
let paddedMessage = addPadding(message, MessageSize)
let packetBytes = wrapInSphinxPacket(paddedMessage, publicKeys, delay, hops, dest)
.expect("Sphinx wrap error")
@@ -186,3 +186,150 @@ suite "Sphinx Tests":
check:
processedSP3.status == Exit
processedSP3.messageChunk == paddedMessage
test "create and use surb":
let (message, privateKeys, publicKeys, delay, hops, _) = createDummyData()
let surb =
createSURB(publicKeys, delay, hops, randomI()).expect("Create SURB error")
let packetBytes = useSURB(surb, message).serialize()
check packetBytes.len == PacketSize
let packet = SphinxPacket.deserialize(packetBytes).expect("Sphinx wrap error")
let processedSP1 =
processSphinxPacket(packet, privateKeys[0], tm).expect("Sphinx processing error")
check:
processedSP1.status == Intermediate
processedSP1.serializedSphinxPacket.len == PacketSize
let processedPacket1 = SphinxPacket
.deserialize(processedSP1.serializedSphinxPacket)
.expect("Sphinx wrap error")
let processedSP2 = processSphinxPacket(processedPacket1, privateKeys[1], tm).expect(
"Sphinx processing error"
)
check:
processedSP2.status == Intermediate
processedSP2.serializedSphinxPacket.len == PacketSize
let processedPacket2 = SphinxPacket
.deserialize(processedSP2.serializedSphinxPacket)
.expect("Sphinx wrap error")
let processedSP3 = processSphinxPacket(processedPacket2, privateKeys[2], tm).expect(
"Sphinx processing error"
)
check processedSP3.status == Reply
let msg = processReply(surb.key, surb.secret.get(), processedSP3.delta_prime).expect(
"Reply processing failed"
)
check msg == message
test "create surb empty public keys":
let (message, _, _, delay, _, _) = createDummyData()
check createSURB(@[], delay, @[], randomI()).isErr()
test "surb sphinx process invalid mac":
let (message, privateKeys, publicKeys, delay, hops, _) = createDummyData()
let surb =
createSURB(publicKeys, delay, hops, randomI()).expect("Create SURB error")
let packetBytes = useSURB(surb, message).serialize()
check packetBytes.len == PacketSize
# Corrupt the MAC for testing
var tamperedPacketBytes = packetBytes
tamperedPacketBytes[0] = packetBytes[0] xor 0x01
let tamperedPacket =
SphinxPacket.deserialize(tamperedPacketBytes).expect("Sphinx wrap error")
let processedSP1 = processSphinxPacket(tamperedPacket, privateKeys[0], tm).expect(
"Sphinx processing error"
)
check processedSP1.status == InvalidMAC
test "surb sphinx process duplicate tag":
let (message, privateKeys, publicKeys, delay, hops, _) = createDummyData()
let surb =
createSURB(publicKeys, delay, hops, randomI()).expect("Create SURB error")
let packetBytes = useSURB(surb, message).serialize()
check packetBytes.len == PacketSize
let packet = SphinxPacket.deserialize(packetBytes).expect("Sphinx wrap error")
# Process the packet twice to test duplicate tag handling
let processedSP1 =
processSphinxPacket(packet, privateKeys[0], tm).expect("Sphinx processing error")
check processedSP1.status == Intermediate
let processedSP2 =
processSphinxPacket(packet, privateKeys[0], tm).expect("Sphinx processing error")
check processedSP2.status == Duplicate
test "create and use surb message sizes":
let messageSizes = @[32, 64, 128, 256, 512]
for size in messageSizes:
let (_, privateKeys, publicKeys, delay, hops, _) = createDummyData()
var message = newSeq[byte](size)
randomize()
for i in 0 ..< size:
message[i] = byte(rand(256))
let paddedMessage = addPadding(message, MessageSize)
let surb =
createSURB(publicKeys, delay, hops, randomI()).expect("Create SURB error")
let packetBytes = useSURB(surb, Message(paddedMessage)).serialize()
check packetBytes.len == PacketSize
let packet = SphinxPacket.deserialize(packetBytes).expect("Sphinx wrap error")
let processedSP1 = processSphinxPacket(packet, privateKeys[0], tm).expect(
"Sphinx processing error"
)
check:
processedSP1.status == Intermediate
processedSP1.serializedSphinxPacket.len == PacketSize
let processedPacket1 = SphinxPacket
.deserialize(processedSP1.serializedSphinxPacket)
.expect("Sphinx wrap error")
let processedSP2 = processSphinxPacket(processedPacket1, privateKeys[1], tm)
.expect("Sphinx processing error")
check:
processedSP2.status == Intermediate
processedSP2.serializedSphinxPacket.len == PacketSize
let processedPacket2 = SphinxPacket
.deserialize(processedSP2.serializedSphinxPacket)
.expect("Sphinx wrap error")
let processedSP3 = processSphinxPacket(processedPacket2, privateKeys[2], tm)
.expect("Sphinx processing error")
check processedSP3.status == Reply
let msg = processReply(surb.key, surb.secret.get(), processedSP3.delta_prime)
.expect("Reply processing failed")
check paddedMessage == msg

View File

@@ -45,5 +45,5 @@ when defined(libp2p_autotls_support):
import
mix/[
testcrypto, testcurve25519, testtagmanager, testseqnogenerator, testserialization,
testmixmessage, testsphinx, testmultiaddr,
testmixmessage, testsphinx, testmultiaddr, testfragmentation,
]