From 2cb3cba705230e147dc2780ff7ff52d6c26feb09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?rich=CE=9Brd?= Date: Wed, 3 Sep 2025 14:40:33 -0400 Subject: [PATCH] feat: add create, use and process SURB functions; split computeBetaGammaDelta (#68) --- mix/exit_connection.nim | 42 +++++- mix/mix_protocol.nim | 3 + mix/sphinx.nim | 148 +++++++++++++++++----- tests/test_sphinx.nim | 274 +++++++++++++++++++++++++++++++++++++++- 4 files changed, 430 insertions(+), 37 deletions(-) diff --git a/mix/exit_connection.nim b/mix/exit_connection.nim index d66b949..3e6512b 100644 --- a/mix/exit_connection.nim +++ b/mix/exit_connection.nim @@ -1,5 +1,6 @@ -import hashes, chronos, libp2p/varint +import hashes, chronos, libp2p/varint, stew/byteutils import libp2p/stream/connection +from fragmentation import dataSize type MixExitConnection* = ref object of Connection message: seq[byte] @@ -113,15 +114,46 @@ method readLp*( self.message = self.message[int(length) .. ^1] return result +method write*( + self: MixExitConnection, msg: seq[byte] +): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true), public.} = + # TODO: dial back + discard + +proc write*( + self: MixExitConnection, msg: string +): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true), public.} = + self.write(msg.toBytes()) + method writeLp*( self: MixExitConnection, msg: openArray[byte] -): Future[void] {.async: (raises: [CancelledError, LPStreamError]), public.} = - raise newException(LPStreamError, "writeLp not implemented for MixExitConnection") +): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true), public.} = + if msg.len() > dataSize: + let fut = newFuture[void]() + fut.fail( + newException(LPStreamError, "exceeds max msg size of " & $dataSize & " bytes") + ) + return fut + + var + vbytes: seq[byte] = @[] + value = msg.len().uint64 + + while value >= 128: + vbytes.add(byte((value and 127) or 128)) + value = value shr 7 + vbytes.add(byte(value)) + + var buf = newSeqUninitialized[byte](msg.len() + vbytes.len) + buf[0 ..< vbytes.len] = vbytes.toOpenArray(0, vbytes.len - 1) + buf[vbytes.len ..< buf.len] = msg + + # TODO: dial back method writeLp*( self: MixExitConnection, msg: string -): Future[void] {.async: (raises: [CancelledError, LPStreamError]), public.} = - raise newException(LPStreamError, "writeLp not implemented for MixExitConnection") +): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true), public.} = + self.writeLp(msg.toOpenArrayByte(0, msg.high)) func shortLog*(self: MixExitConnection): string {.raises: [].} = "MixExitConnection" diff --git a/mix/mix_protocol.nim b/mix/mix_protocol.nim index 91e98f7..547fb2e 100644 --- a/mix/mix_protocol.nim +++ b/mix/mix_protocol.nim @@ -163,6 +163,9 @@ proc handleMixNodeConnection( except CatchableError as e: error "Failed to close outgoing stream: ", err = e.msg mix_messages_error.inc(labelValues = ["Intermediate", "DAIL_FAILED"]) + of Reply: + # TODO: implement + discard of Duplicate: mix_messages_error.inc(labelValues = ["Intermediate/Exit", "DUPLICATE"]) discard diff --git a/mix/sphinx.nim b/mix/sphinx.nim index 1062d5f..11e0792 100644 --- a/mix/sphinx.nim +++ b/mix/sphinx.nim @@ -1,10 +1,12 @@ import results, sequtils +import nimcrypto/sysrand import ./[config, crypto, curve25519, serialization, tag_manager] # Define possible outcomes of processing a Sphinx packet type ProcessingStatus* = enum Exit # Packet processed successfully at exit Intermediate # Packet processed successfully at intermediate node + Reply # Reply received at entry node for a message succesfuly processed at exit node Duplicate # Packet was discarded due to duplicate tag InvalidMAC # Packet was discarded due to MAC verification failure @@ -13,6 +15,9 @@ type ProcessingStatus* = enum # Function to compute alphas, shared secrets, and blinders +# Compute alpha, an ephemeral public value. Each mix node uses its private key and +# alpha to derive a shared session key for that hop. +# This session key is used to decrypt and process one layer of the packet. proc computeAlpha( publicKeys: openArray[FieldElement] ): Result[(seq[byte], seq[seq[byte]]), string] = @@ -109,34 +114,28 @@ proc generateRandomDelay(): seq[byte] = const paddingLength = (((t + 1) * (r - L)) + 2) * k -# Function to compute betas, gammas, and deltas -proc computeBetaGammaDelta( - s: seq[seq[byte]], - hop: openArray[Hop], - msg: Message, - delay: openArray[seq[byte]], - destHop: Hop, -): Result[(seq[byte], seq[byte], seq[byte]), string] = # TODO: name tuples +# Function to compute: +# Beta: The nested encrypted routing information. It encodes the next hop address, the forwarding delay, integrity check Gamma for the next hop, and the Beta for subsequent hops. +# Gamma: A message authentication code computed over Beta using the session key derived from Alpha. It ensures header integrity at each hop. +proc computeBetaGamma( + s: seq[seq[byte]], hop: openArray[Hop], delay: openArray[seq[byte]], destHop: Hop +): Result[(seq[byte], seq[byte]), string] = # TODO: name tuples let sLen = s.len var beta: seq[byte] gamma: seq[byte] - delta: seq[byte] # Compute filler strings let filler = computeFillerStrings(s).valueOr: return err("Error in filler generation: " & error) for i in countdown(sLen - 1, 0): - # Derive AES keys, MAC key, and IVs + # Derive AES key, MAC key, and IV let beta_aes_key = kdf(deriveKeyMaterial("aes_key", s[i])) mac_key = kdf(deriveKeyMaterial("mac_key", s[i])) beta_iv = kdf(deriveKeyMaterial("iv", s[i])) - delta_aes_key = kdf(deriveKeyMaterial("delta_aes_key", s[i])) - delta_iv = kdf(deriveKeyMaterial("delta_iv", s[i])) - # Compute Beta and Gamma if i == sLen - 1: let destBytes = ?destHop.serialize() @@ -145,12 +144,6 @@ proc computeBetaGammaDelta( let aes = aes_ctr(beta_aes_key, beta_iv, destPadding).valueOr: return err("Error in aes: " & error) beta = aes & filler - - let serializedMsg = msg.serialize().valueOr: - return err("Message serialization error: " & error) - - delta = aes_ctr(delta_aes_key, delta_iv, serializedMsg).valueOr: - return err("Error in aes: " & error) else: let routingInfo = RoutingInfo.init( hop[i + 1], delay[i], gamma, beta[0 .. (((r * (t + 1)) - t) * k) - 1] @@ -162,12 +155,100 @@ proc computeBetaGammaDelta( beta = aes_ctr(beta_aes_key, beta_iv, serializedRoutingInfo).valueOr: return err("Error in aes: " & error) + gamma = toSeq(hmac(mac_key, beta)) + + return ok((beta, gamma)) + +# Function to compute deltas +proc computeDelta(s: seq[seq[byte]], msg: Message): Result[seq[byte], string] = + let sLen = s.len + var delta: seq[byte] + + for i in countdown(sLen - 1, 0): + # Derive AES key and IV + let + delta_aes_key = kdf(deriveKeyMaterial("delta_aes_key", s[i])) + delta_iv = kdf(deriveKeyMaterial("delta_iv", s[i])) + + # Compute Delta + if i == sLen - 1: + let serializedMsg = msg.serialize().valueOr: + return err("Message serialization error: " & error) + + delta = aes_ctr(delta_aes_key, delta_iv, serializedMsg).valueOr: + return err("Error in aes: " & error) + else: delta = aes_ctr(delta_aes_key, delta_iv, delta).valueOr: return err("Error in aes: " & error) - gamma = toSeq(hmac(mac_key, beta)) + return ok(delta) - return ok((beta, gamma, delta)) +proc createSURB*( + publicKeys: openArray[FieldElement], + delay: openArray[seq[byte]], + hop: openArray[Hop], + destHop: Hop, +): Result[(Hop, Header, seq[seq[byte]], seq[byte]), string] = + # Compute alpha and shared secrets + let (alpha_0, s) = computeAlpha(publicKeys).valueOr: + return err("Error in alpha generation: " & error) + + # Compute beta and gamma + let (beta_0, gamma_0) = computeBetaGamma(s, hop, delay, destHop).valueOr: + return err("Error in beta and gamma generation: " & error) + + # Generate key + var key = newSeqUninitialized[byte](k) + discard randomBytes(key) + + return ok((hop[0], Header.init(alpha_0, beta_0, gamma_0), s, key)) + +proc useSURB*(header: Header, key: seq[byte], msg: Message): Result[seq[byte], string] = + # Derive AES key and IV + let + delta_aes_key = kdf(deriveKeyMaterial("delta_aes_key", key)) + delta_iv = kdf(deriveKeyMaterial("delta_iv", key)) + + # Compute Delta + let serializedMsg = msg.serialize.valueOr: + return err("Message serialization error: " & error) + + let delta = aes_ctr(delta_aes_key, delta_iv, serializedMsg).valueOr: + return err("Error in aes: " & error) + + let serialized = SphinxPacket.init(header, delta).serialize().valueOr: + return err("Sphinx packet serialization error: " & error) + + return ok(serialized) + +proc processReply*( + key: seq[byte], s: seq[seq[byte]], delta_prime: seq[byte] +): Result[seq[byte], string] = + var delta = delta_prime[0 ..^ 1] + + for i in 0 .. s.len: + var key_prime: seq[byte] + + if i == 0: + key_prime = key + else: + key_prime = s[i - 1] + + # Derive AES key and IV + let + delta_aes_key = kdf(deriveKeyMaterial("delta_aes_key", key_prime)) + delta_iv = kdf(deriveKeyMaterial("delta_iv", key_prime)) + + let deltaRes = aes_ctr(delta_aes_key, delta_iv, delta) + if deltaRes.isErr: + return err("Error in aes: " & deltaRes.error) + delta = deltaRes.get() + + let deserializeMsg = Message.deserialize(delta).valueOr: + return err("Message deserialization error: " & error) + let content = getContent(deserializeMsg) + + return ok(content) proc wrapInSphinxPacket*( msg: Message, @@ -176,13 +257,17 @@ proc wrapInSphinxPacket*( hop: openArray[Hop], destHop: Hop, ): Result[seq[byte], string] = - # Compute alphas and shared secrets + # Compute alpha and shared secrets let (alpha_0, s) = computeAlpha(publicKeys).valueOr: return err("Error in alpha generation: " & error) - # Compute betas, gammas, and deltas - let (beta_0, gamma_0, delta_0) = computeBetaGammaDelta(s, hop, msg, delay, destHop).valueOr: - return err("Error in beta, gamma, and delta generation: " & error) + # Compute beta and gamma + let (beta_0, gamma_0) = computeBetaGamma(s, hop, delay, destHop).valueOr: + return err("Error in beta and gamma generation: " & error) + + # Compute delta + let delta_0 = computeDelta(s, msg).valueOr: + return err("Error in delta generation: " & error) # Serialize sphinx packet let sphinxPacket = SphinxPacket.init(Header.init(alpha_0, beta_0, gamma_0), delta_0) @@ -243,12 +328,17 @@ proc processSphinxPacket*( zeroPadding = newSeq[byte](paddingLength) if B[(t * k) .. (t * k) + paddingLength - 1] == zeroPadding: - let msg = Message.deserialize(delta_prime).valueOr: - return err("Message deserialization error: " & error) - let content = msg.getContent() let hop = Hop.deserialize(B[0 .. addrSize - 1]).valueOr: return err(error) - return ok((hop, B[addrSize .. ((t * k) - 1)], content[0 .. messageSize - 1], Exit)) + + if delta_prime[0 .. (k - 1)] == newSeq[byte](k): + let msg = Message.deserialize(delta_prime).valueOr: + return err("Message deserialization error: " & error) + let content = msg.getContent() + return + ok((hop, B[addrSize .. ((t * k) - 1)], content[0 .. messageSize - 1], Exit)) + else: + return ok((hop, B[addrSize .. ((t * k) - 1)], delta_prime, Reply)) else: # Extract routing information from B let routingInfo = RoutingInfo.deserialize(B).valueOr: diff --git a/tests/test_sphinx.nim b/tests/test_sphinx.nim index 2b80a82..a1658ee 100644 --- a/tests/test_sphinx.nim +++ b/tests/test_sphinx.nim @@ -49,7 +49,6 @@ proc createDummyData(): ( message = Message.init(newSeq[byte](messageSize)) dest = Hop.init(newSeq[byte](addrSize)) - return (message, privateKeys, publicKeys, delay, hops, dest) # Unit tests for sphinx.nim @@ -134,8 +133,7 @@ suite "Sphinx Tests": error "Processing status should be Exit", status3 fail() - let processedMessage = Message.init(processedPacket3) - if processedMessage != message: + if processedPacket3 != message.getContent(): error "Packet processing failed" fail() @@ -304,3 +302,273 @@ suite "Sphinx Tests": if processedPacket3 != paddedMessage: error "Packet processing failed" fail() + + test "create_and_use_surb": + let (message, privateKeys, publicKeys, delay, hops, dest) = createDummyData() + + let surbRes = createSURB(publicKeys, delay, hops, dest) + if surbRes.isErr: + error "Create SURB error", err = surbRes.error + let (hop, header, s, key) = surbRes.get() + + let packetBytesRes = useSURB(header, key, message) + if packetBytesRes.isErr: + error "Use SURB error", err = packetBytesRes.error + let packetBytes = packetBytesRes.get() + + if packetBytes.len != packetSize: + error "Packet length is not valid", + pkt_len = $(packetBytes.len), expected_len = $packetSize + fail() + + let packetRes = SphinxPacket.deserialize(packetBytes) + if packetRes.isErr: + error "Sphinx wrap error", err = packetRes.error + fail() + let packet = packetRes.get() + + let res1 = processSphinxPacket(packet, privateKeys[0], tm) + if res1.isErr: + error "Error in Sphinx processing", err = res1.error + fail() + let (address1, delay1, processedPacket1Bytes, status1) = res1.get() + + if status1 != Intermediate: + error "Processing status should be Intermediate" + fail() + + if processedPacket1Bytes.len != packetSize: + error "Packet length is not valid", + pkt_len = $(processedPacket1Bytes.len), expected_len = $packetSize + fail() + + let processedPacket1Res = SphinxPacket.deserialize(processedPacket1Bytes) + if processedPacket1Res.isErr: + error "Sphinx wrap error", err = processedPacket1Res.error + fail() + let processedPacket1 = processedPacket1Res.get() + + let res2 = processSphinxPacket(processedPacket1, privateKeys[1], tm) + if res2.isErr: + error "Error in Sphinx processing", err = res2.error + fail() + let (address2, delay2, processedPacket2Bytes, status2) = res2.get() + + if status2 != Intermediate: + error "Processing status should be Success" + fail() + + if processedPacket2Bytes.len != packetSize: + error "Packet length is not valid", + pkt_len = $(processedPacket2Bytes.len), expected_len = $packetSize + fail() + + let processedPacket2Res = SphinxPacket.deserialize(processedPacket2Bytes) + if processedPacket2Res.isErr: + error "Sphinx wrap error", err = processedPacket2Res.error + fail() + let processedPacket2 = processedPacket2Res.get() + + let res3 = processSphinxPacket(processedPacket2, privateKeys[2], tm) + if res3.isErr: + error "Error in Sphinx processing", err = res3.error + fail() + let (address3, delay3, processedPacket3, status3) = res3.get() + + if status3 != Reply: + error "Processing status should be Reply" + fail() + + let msgRes = processReply(key, s, processedPacket3) + if msgRes.isErr: + error "Reply processing failed", err = msgRes.error + let msg = msgRes.get() + + if msg != message.getContent(): + error "Message tampered" + fail() + + test "create_surb_empty_public_keys": + let (message, _, _, delay, _, dest) = createDummyData() + + let surbRes = createSURB(@[], delay, @[], dest) + if surbRes.isOk: + error "Expected create SURB error when public keys are empty, but got success" + fail() + + test "surb_sphinx_process_invalid_mac": + let (message, privateKeys, publicKeys, delay, hops, dest) = createDummyData() + + let surbRes = createSURB(publicKeys, delay, hops, dest) + if surbRes.isErr: + error "Create SURB error", err = surbRes.error + let (hop, header, s, key) = surbRes.get() + + let packetRes = useSURB(header, key, message) + if packetRes.isErr: + error "Use SURB error", err = packetRes.error + let packet = packetRes.get() + + if packet.len != packetSize: + error "Packet length is not valid", + pkt_len = $(packet.len), expected_len = $packetSize + fail() + + # Corrupt the MAC for testing + var tamperedPacketBytes = packet + tamperedPacketBytes[0] = packet[0] xor 0x01 + + let tamperedPacketRes = SphinxPacket.deserialize(tamperedPacketBytes) + if tamperedPacketRes.isErr: + error "Sphinx wrap error", err = tamperedPacketRes.error + fail() + let tamperedPacket = tamperedPacketRes.get() + + let res = processSphinxPacket(tamperedPacket, privateKeys[0], tm) + if res.isErr: + error "Error in Sphinx processing", err = res.error + fail() + let (_, _, _, status) = res.get() + + if status != InvalidMAC: + error "Processing status should be InvalidMAC" + fail() + + test "surb_sphinx_process_duplicate_tag": + let (message, privateKeys, publicKeys, delay, hops, dest) = createDummyData() + + let surbRes = createSURB(publicKeys, delay, hops, dest) + if surbRes.isErr: + error "Create SURB error", err = surbRes.error + let (hop, header, s, key) = surbRes.get() + + let packetBytesRes = useSURB(header, key, message) + if packetBytesRes.isErr: + error "Use SURB error", err = packetBytesRes.error + let packetBytes = packetBytesRes.get() + + if packetBytes.len != packetSize: + error "Packet length is not valid", + pkt_len = $(packetBytes.len), expected_len = $packetSize + fail() + + let packetRes = SphinxPacket.deserialize(packetBytes) + if packetRes.isErr: + error "Sphinx wrap error", err = packetRes.error + fail() + let packet = packetRes.get() + + # Process the packet twice to test duplicate tag handling + let res1 = processSphinxPacket(packet, privateKeys[0], tm) + if res1.isErr: + error "Error in Sphinx processing", err = res1.error + fail() + let (_, _, _, status1) = res1.get() + + if status1 != Intermediate: + error "Processing status should be Success" + fail() + + let res2 = processSphinxPacket(packet, privateKeys[0], tm) + if res2.isErr: + error "Error in Sphinx processing", err = res2.error + fail() + let (_, _, _, status2) = res2.get() + + if status2 != Duplicate: + error "Processing status should be Duplicate" + fail() + + test "create_and_use_surb_message_sizes": + let messageSizes = @[32, 64, 128, 256, 512] + for size in messageSizes: + let (_, privateKeys, publicKeys, delay, hops, dest) = createDummyData() + var message = newSeq[byte](size) + randomize() + for i in 0 ..< size: + message[i] = byte(rand(256)) + let paddedMessage = padMessage(message, messageSize) + + let surbRes = createSURB(publicKeys, delay, hops, dest) + if surbRes.isErr: + error "Create SURB error", err = surbRes.error + let (hop, header, s, key) = surbRes.get() + + let packetBytesRes = useSURB(header, key, Message.init(paddedMessage)) + if packetBytesRes.isErr: + error "Use SURB error", err = packetBytesRes.error + let packetBytes = packetBytesRes.get() + + if packetBytes.len != packetSize: + error "Packet length is not valid", + pkt_len = $(packetBytes.len), + expected_len = $packetSize, + msg_len = $messageSize + fail() + + let packetRes = SphinxPacket.deserialize(packetBytes) + if packetRes.isErr: + error "Sphinx wrap error", err = packetBytesRes.error + fail() + let packet = packetRes.get() + + let res1 = processSphinxPacket(packet, privateKeys[0], tm) + if res1.isErr: + error "Error in Sphinx processing", err = res1.error + fail() + let (address1, delay1, processedPacket1Bytes, status1) = res1.get() + + if status1 != Intermediate: + error "Processing status should be Success" + fail() + + if processedPacket1Bytes.len != packetSize: + error "Packet length is not valid", + pkt_len = $(processedPacket1Bytes.len), expected_len = $packetSize + fail() + + let processedPacket1Res = SphinxPacket.deserialize(processedPacket1Bytes) + if processedPacket1Res.isErr: + error "Sphinx wrap error", err = processedPacket1Res.error + fail() + let processedPacket1 = processedPacket1Res.get() + + let res2 = processSphinxPacket(processedPacket1, privateKeys[1], tm) + if res2.isErr: + error "Error in Sphinx processing", err = res2.error + fail() + let (address2, delay2, processedPacket2Bytes, status2) = res2.get() + + if status2 != Intermediate: + error "Processing status should be Success" + fail() + + if processedPacket2Bytes.len != packetSize: + error "Packet length is not valid", + pkt_len = $(processedPacket2Bytes.len), expected_len = $packetSize + fail() + + let processedPacket2Res = SphinxPacket.deserialize(processedPacket2Bytes) + if processedPacket1Res.isErr: + error "Sphinx wrap error", err = processedPacket2Res.error + fail() + let processedPacket2 = processedPacket2Res.get() + + let res3 = processSphinxPacket(processedPacket2, privateKeys[2], tm) + if res3.isErr: + error "Error in Sphinx processing", err = res3.error + fail() + let (address3, delay3, processedPacket3, status3) = res3.get() + + if status3 != Reply: + error "Processing status should be Reply" + fail() + + let msgRes = processReply(key, s, processedPacket3) + if msgRes.isErr: + error "Reply processing failed", err = msgRes.error + let msg = msgRes.get() + + if paddedMessage != msg: + error "Message tampered" + fail()