From dcea429e0972582095c7713be74822e625619560 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?rich=CE=9Brd?= Date: Fri, 5 Sep 2025 12:07:05 -0400 Subject: [PATCH] refactor: processSphinxPacket and process replies and return in conn (#78) --- examples/poc_resp_ping.nim | 2 +- mix/entry_connection.nim | 169 ++++++++++++++++++++++++++++++------- mix/exit_layer.nim | 2 +- mix/mix_protocol.nim | 114 ++++++++++++++++++------- mix/reply_connection.nim | 51 +++++------ mix/sphinx.nim | 50 ++++++++--- tests/test_sphinx.nim | 150 +++++++++++++++++--------------- 7 files changed, 371 insertions(+), 167 deletions(-) diff --git a/examples/poc_resp_ping.nim b/examples/poc_resp_ping.nim index 6858efc..f8b675c 100644 --- a/examples/poc_resp_ping.nim +++ b/examples/poc_resp_ping.nim @@ -116,7 +116,7 @@ proc mixnetSimulation() {.async: (raises: [Exception]).} = let response = await pingProto[senderIndex].ping(conn) - await sleepAsync(1.seconds) + info "PING RESPONSE", response deleteNodeInfoFolder() deletePubInfoFolder() diff --git a/mix/entry_connection.nim b/mix/entry_connection.nim index b27eaf6..f9c26ba 100644 --- a/mix/entry_connection.nim +++ b/mix/entry_connection.nim @@ -1,5 +1,6 @@ import hashes, chronos, stew/byteutils, results, chronicles import libp2p/stream/connection +import libp2p/varint import ./mix_protocol import ./config from fragmentation import dataSize @@ -44,27 +45,134 @@ type MixEntryConnection* = ref object of Connection mixDialer: MixDialer params: Opt[MixParameters] + incoming: AsyncQueue[seq[byte]] + incomingFut: Future[void] + replyReceivedFut: Future[void] + cached: seq[byte] + +method readOnce*( + s: MixEntryConnection, pbytes: pointer, nbytes: int +): Future[int] {.async: (raises: [CancelledError, LPStreamError]), public.} = + if s.isEof: + raise newLPStreamEOFError() + + try: + await s.replyReceivedFut + s.isEof = true + if s.cached.len == 0: + raise newLPStreamEOFError() + except CancelledError as exc: + raise exc + except LPStreamEOFError as exc: + raise exc + except CatchableError as exc: + raise (ref LPStreamError)(msg: "error in readOnce: " & exc.msg, parent: exc) + + let toRead = min(nbytes, s.cached.len) + copyMem(pbytes, addr s.cached[0], toRead) + s.cached = s.cached[toRead ..^ 1] + return toRead + method readExactly*( - self: MixEntryConnection, pbytes: pointer, nbytes: int + s: MixEntryConnection, pbytes: pointer, nbytes: int ): Future[void] {.async: (raises: [CancelledError, LPStreamError]), public.} = - await sleepAsync(10.minutes) # TODO: implement readExactly - raise - newException(LPStreamError, "readExactly not implemented for MixEntryConnection") + ## Waits for `nbytes` to be available, then read + ## them and return them + if s.atEof: + var ch: char + discard await s.readOnce(addr ch, 1) + raise newLPStreamEOFError() + + if nbytes == 0: + return + + logScope: + s + nbytes = nbytes + objName = s.objName + + var pbuffer = cast[ptr UncheckedArray[byte]](pbytes) + var read = 0 + while read < nbytes and not (s.atEof()): + read += await s.readOnce(addr pbuffer[read], nbytes - read) + + if read == 0: + doAssert s.atEof() + trace "couldn't read all bytes, stream EOF", s, nbytes, read + # Re-readOnce to raise a more specific error than EOF + # Raise EOF if it doesn't raise anything(shouldn't happen) + discard await s.readOnce(addr pbuffer[read], nbytes - read) + + raise newLPStreamEOFError() + + if read < nbytes: + trace "couldn't read all bytes, incomplete data", s, nbytes, read + raise newLPStreamIncompleteError() method readLine*( - self: MixEntryConnection, limit = 0, sep = "\r\n" + s: MixEntryConnection, limit = 0, sep = "\r\n" ): Future[string] {.async: (raises: [CancelledError, LPStreamError]), public.} = - raise newException(LPStreamError, "readLine not implemented for MixEntryConnection") + ## Reads up to `limit` bytes are read, or a `sep` is found + # TODO replace with something that exploits buffering better + var lim = if limit <= 0: -1 else: limit + var state = 0 + + while true: + var ch: char + await readExactly(s, addr ch, 1) + + if sep[state] == ch: + inc(state) + if state == len(sep): + break + else: + state = 0 + if limit > 0: + let missing = min(state, lim - len(result) - 1) + result.add(sep[0 ..< missing]) + else: + result.add(sep[0 ..< state]) + + result.add(ch) + if len(result) == lim: + break method readVarint*( - self: MixEntryConnection + conn: MixEntryConnection ): Future[uint64] {.async: (raises: [CancelledError, LPStreamError]), public.} = - raise newException(LPStreamError, "readVarint not implemented for MixEntryConnection") + var buffer: array[10, byte] + + for i in 0 ..< len(buffer): + await conn.readExactly(addr buffer[i], 1) + + var + varint: uint64 + length: int + let res = PB.getUVarint(buffer.toOpenArray(0, i), length, varint) + if res.isOk(): + return varint + if res.error() != VarintError.Incomplete: + break + if true: # can't end with a raise apparently + raise (ref InvalidVarintError)(msg: "Cannot parse varint") method readLp*( - self: MixEntryConnection, maxSize: int + s: MixEntryConnection, maxSize: int ): Future[seq[byte]] {.async: (raises: [CancelledError, LPStreamError]), public.} = - raise newException(LPStreamError, "readLp not implemented for MixEntryConnection") + ## read length prefixed msg, with the length encoded as a varint + let + length = await s.readVarint() + maxLen = uint64(if maxSize < 0: int.high else: maxSize) + + if length > maxLen: + raise (ref MaxSizeError)(msg: "Message exceeds maximum length") + + if length == 0: + return + + var res = newSeqUninitialized[byte](length) + await s.readExactly(addr res[0], res.len) + res method write*( self: MixEntryConnection, msg: seq[byte] @@ -112,6 +220,7 @@ proc shortLog*(self: MixEntryConnection): string {.raises: [].} = method closeImpl*( self: MixEntryConnection ): Future[void] {.async: (raises: [], raw: true).} = + self.incomingFut.cancelSoon() let fut = newFuture[void]() fut.complete() return fut @@ -123,22 +232,6 @@ when defined(libp2p_agents_metrics): proc setShortAgent*(self: MixEntryConnection, shortAgent: string) = discard -proc new*( - T: typedesc[MixEntryConnection], - srcMix: MixProtocol, - destination: Destination, - codec: string, - mixDialer: MixDialer, - params: Opt[MixParameters], -): T = - let instance = - T(destination: destination, codec: codec, mixDialer: mixDialer, params: params) - - when defined(libp2p_agents_metrics): - instance.shortAgent = connection.shortAgent - - instance - proc new*( T: typedesc[MixEntryConnection], srcMix: MixProtocol, @@ -154,7 +247,20 @@ proc new*( else: 0 - var sendDialerFunc = proc( + var instance = T() + instance.destination = destination + instance.codec = codec + instance.params = Opt.some(params) + + if expectReply: + instance.incoming = newAsyncQueue[seq[byte]]() + instance.replyReceivedFut = newFuture[void]() + let checkForIncoming = proc(): Future[void] {.async: (raises: [CancelledError]).} = + instance.cached = await instance.incoming.get() + instance.replyReceivedFut.complete() + instance.incomingFut = checkForIncoming() + + instance.mixDialer = proc( msg: seq[byte], codec: string, dest: Destination ): Future[void] {.async: (raises: [CancelledError, LPStreamError]).} = try: @@ -164,12 +270,17 @@ proc new*( else: (Opt.none(PeerId), Opt.some(MixDestination.init(dest.peerId, dest.address))) - await srcMix.anonymizeLocalProtocolSend(msg, codec, peerId, destination, surbs) + await srcMix.anonymizeLocalProtocolSend( + instance.incoming, msg, codec, peerId, destination, surbs + ) except CatchableError as e: error "Error during execution of anonymizeLocalProtocolSend: ", err = e.msg return - T.new(srcMix, destination, codec, sendDialerFunc, Opt.some(params)) + when defined(libp2p_agents_metrics): + instance.shortAgent = connection.shortAgent + + instance proc toConnection*( srcMix: MixProtocol, diff --git a/mix/exit_layer.nim b/mix/exit_layer.nim index 4b59aa1..ae06763 100644 --- a/mix/exit_layer.nim +++ b/mix/exit_layer.nim @@ -79,7 +79,7 @@ proc reply( if not replyConn.isNil: await replyConn.close() try: - await replyConn.writeLp(response) + await replyConn.write(response) except LPStreamError as exc: error "could not reply", description = exc.msg mix_messages_error.inc(labelValues = ["ExitLayer", "REPLY_FAILED"]) diff --git a/mix/mix_protocol.nim b/mix/mix_protocol.nim index b31e7a2..2903075 100644 --- a/mix/mix_protocol.nim +++ b/mix/mix_protocol.nim @@ -12,6 +12,10 @@ import const MixProtocolID* = "/mix/1.0.0" +type ConnCreds = object + incoming: AsyncQueue[seq[byte]] + surbSKSeq: seq[(secret, key)] + type MixProtocol* = ref object of LPProtocol mixNodeInfo: MixNodeInfo pubNodeInfo: Table[PeerId, MixPubInfo] @@ -19,8 +23,8 @@ type MixProtocol* = ref object of LPProtocol tagManager: TagManager exitLayer: ExitLayer rng: ref HmacDrbgContext - # TODO: might require cleanup? - idToSKey: Table[array[surbIdLen, byte], seq[(secret, key)]] + # TODO: verify if this requires cleanup for cases in which response never arrives (and connection is closed) + connCreds: Table[I, ConnCreds] fwdRBehavior: TableRef[string, fwdReadBehaviorCb] proc hasFwdBehavior*(mixProto: MixProtocol, codec: string): bool = @@ -91,55 +95,91 @@ proc handleMixNodeConnection( mix_messages_error.inc(labelValues = ["Intermediate/Exit", "INVALID_SPHINX"]) return - let (nextHop, delay, processedPkt, status) = processSphinxPacket( - sphinxPacket, mixPrivKey, mixProto.tagManager - ).valueOr: + let processedSP = processSphinxPacket(sphinxPacket, mixPrivKey, mixProto.tagManager).valueOr: error "Failed to process Sphinx packet", err = error mix_messages_error.inc(labelValues = ["Intermediate/Exit", "INVALID_SPHINX"]) return - case status + case processedSP.status of Exit: - mix_messages_recvd.inc(labelValues = [$status]) + mix_messages_recvd.inc(labelValues = [$processedSP.status]) # This is the exit node, forward to destination - let msgChunk = MessageChunk.deserialize(processedPkt).valueOr: + let msgChunk = MessageChunk.deserialize(processedSP.messageChunk).valueOr: error "Deserialization failed", err = error - mix_messages_error.inc(labelValues = [$status, "INVALID_SPHINX"]) + mix_messages_error.inc(labelValues = ["Exit", "INVALID_SPHINX"]) return let unpaddedMsg = unpadMessage(msgChunk).valueOr: error "Unpadding message failed", err = error - mix_messages_error.inc(labelValues = [$status, "INVALID_SPHINX"]) + mix_messages_error.inc(labelValues = ["Exit", "INVALID_SPHINX"]) return let deserialized = MixMessage.deserialize(unpaddedMsg).valueOr: error "Deserialization failed", err = error - mix_messages_error.inc(labelValues = [$status, "INVALID_SPHINX"]) + mix_messages_error.inc(labelValues = ["Exit", "INVALID_SPHINX"]) return let (surbs, message) = extractSURBs(deserialized.message).valueOr: error "Extracting surbs from payload failed", err = error - mix_messages_error.inc(labelValues = [$status, "INVALID_MSG_SURBS"]) + mix_messages_error.inc(labelValues = ["Exit", "INVALID_MSG_SURBS"]) return trace "Exit node - Received mix message", receiver = multiAddr, message = deserialized.message, codec = deserialized.codec - await mixProto.exitLayer.onMessage(deserialized.codec, message, nextHop, surbs) + await mixProto.exitLayer.onMessage( + deserialized.codec, message, processedSP.destination, surbs + ) - mix_messages_forwarded.inc(labelValues = [$status]) + mix_messages_forwarded.inc(labelValues = ["Exit"]) of Reply: - error "TODO: IMPLEMENT REPLY STATE" - # TODO: process reply at entry side + trace "# Reply", id = processedSP.id + try: + if not mixProto.connCreds.hasKey(processedSP.id): + mix_messages_error.inc(labelValues = ["Sender/Reply", "NO_CONN_FOUND"]) + return + + let connCred = mixProto.connCreds[processedSP.id] + mixProto.connCreds.del(processedSP.id) + + var couldProcessReply = false + var reply: seq[byte] + for sk in connCred.surbSKSeq: + let processReplyRes = processReply(sk[1], sk[0], processedSP.delta_prime) + if processReplyRes.isOk: + couldProcessReply = true + reply = processReplyRes.value() + break + + if couldProcessReply: + let msgChunk = MessageChunk.deserialize(reply).valueOr: + error "Deserialization failed", err = error + mix_messages_error.inc(labelValues = ["Reply", "INVALID_SPHINX"]) + return + + let unpaddedMsg = unpadMessage(msgChunk).valueOr: + error "Unpadding message failed", err = error + mix_messages_error.inc(labelValues = ["Reply", "INVALID_SPHINX"]) + return + + let deserialized = MixMessage.deserialize(unpaddedMsg).valueOr: + error "Deserialization failed", err = error + mix_messages_error.inc(labelValues = ["Reply", "INVALID_SPHINX"]) + return + + await connCred.incoming.put(deserialized.message) + else: + error "could not process reply", id = processedSP.id + except KeyError as ex: + doAssert false, "checked with hasKey" of Intermediate: trace "# Intermediate: ", multiAddr = multiAddr # Add delay - let delayMillis = (delay[0].int shl 8) or delay[1].int mix_messages_recvd.inc(labelValues = ["Intermediate"]) - await sleepAsync(milliseconds(delayMillis)) + await sleepAsync(milliseconds(processedSP.delayMs)) # Forward to next hop - let nextHopBytes = getHop(nextHop) + let nextHopBytes = getHop(processedSP.nextHop) let fullAddrStr = bytesToMultiAddr(nextHopBytes).valueOr: error "Failed to convert bytes to multiaddress", err = error @@ -170,7 +210,7 @@ proc handleMixNodeConnection( var nextHopConn: Connection try: nextHopConn = await mixProto.switch.dial(peerId, @[locationAddr], MixProtocolID) - await nextHopConn.writeLp(processedPkt) + await nextHopConn.writeLp(processedSP.serializedSphinxPacket) mix_messages_forwarded.inc(labelValues = ["Intermediate"]) except CatchableError as e: error "Failed to dial next hop: ", err = e.msg @@ -200,7 +240,10 @@ proc getMaxMessageSizeForCodec*( return ok(dataSize - totalLen) proc buildSurbs( - mixProto: MixProtocol, numSurbs: uint8, skipPeer: PeerId + mixProto: MixProtocol, + incoming: AsyncQueue[seq[byte]], + numSurbs: uint8, + exitPeerId: PeerId, ): Result[seq[SURB], string] = var response: seq[SURB] var surbSK: seq[(secret, key)] = @[] @@ -225,6 +268,13 @@ proc buildSurbs( randPeerId: PeerId availableIndices = toSeq(0 ..< numMixNodes) + # Remove exit node from nodes to consider for surbs + let index = pubNodeInfoKeys.find(exitPeerId) + if index != -1: + availableIndices.del(index) + else: + return err("could not find exit node") + var i = 0 while i < L: let (multiAddr, mixPubKey, delayMillisec) = @@ -233,9 +283,6 @@ proc buildSurbs( return err("failed to generate random num: " & error) let selectedIndex = availableIndices[randomIndexPosition] randPeerId = pubNodeInfoKeys[selectedIndex] - if randPeerId == skipPeer: - continue - availableIndices.del(randomIndexPosition) debug "Selected mix node for surbs: ", indexInPath = i, peerId = randPeerId let mixPubInfo = getMixPubInfo(mixProto.pubNodeInfo.getOrDefault(randPeerId)) @@ -269,14 +316,18 @@ proc buildSurbs( response.add(surb) if surbSK.len != 0: - mixProto.idToSKey[id] = surbSK + mixProto.connCreds[id] = ConnCreds(surbSKSeq: surbSK, incoming: incoming) return ok(response) proc prepareMsgWithSurbs( - mixProto: MixProtocol, msg: seq[byte], numSurbs: uint8 = 0, skipPeer: PeerId + mixProto: MixProtocol, + incoming: AsyncQueue[seq[byte]], + msg: seq[byte], + numSurbs: uint8 = 0, + exitPeerId: PeerId, ): Result[seq[byte], string] = - let surbs = buildSurbs(mixProto, numSurbs, skipPeer).valueOr: + let surbs = mixProto.buildSurbs(incoming, numSurbs, exitPeerId).valueOr: return err(error) let serialized = ?serializeMessageWithSURBs(msg, surbs) @@ -352,6 +403,7 @@ proc `$`*(d: MixDestination): string = proc anonymizeLocalProtocolSend*( mixProto: MixProtocol, + incoming: AsyncQueue[seq[byte]], msg: seq[byte], codec: string, destPeerId: Opt[PeerId], @@ -374,7 +426,7 @@ proc anonymizeLocalProtocolSend*( publicKeys: seq[FieldElement] = @[] hop: seq[Hop] = @[] delay: seq[seq[byte]] = @[] - exitNode: PeerId + exitPeerId: PeerId # Select L mix nodes at random let numMixNodes = mixProto.pubNodeInfo.len @@ -410,7 +462,7 @@ proc anonymizeLocalProtocolSend*( while i < L: if destPeerId.isSome and i == L - 1: randPeerId = destPeerId.value() - exitNode = destPeerId.value() + exitPeerId = destPeerId.value() else: let randomIndexPosition = cryptoRandomInt(availableIndices.len).valueOr: error "Failed to genanrate random number", err = error @@ -426,7 +478,7 @@ proc anonymizeLocalProtocolSend*( continue # Last hop will be the exit node that will forward the request if i == L - 1: - exitNode = randPeerId + exitPeerId = randPeerId debug "Selected mix node: ", indexInPath = i, peerId = randPeerId @@ -469,7 +521,7 @@ proc anonymizeLocalProtocolSend*( else: Hop() - let msgWithSurbs = prepareMsgWithSurbs(mixProto, msg, numSurbs, exitNode).valueOr: + let msgWithSurbs = mixProto.prepareMsgWithSurbs(incoming, msg, numSurbs, exitPeerId).valueOr: error "Could not prepend SURBs", err = error return diff --git a/mix/reply_connection.nim b/mix/reply_connection.nim index 4ae0e66..6e2df0e 100644 --- a/mix/reply_connection.nim +++ b/mix/reply_connection.nim @@ -1,4 +1,4 @@ -import hashes, chronos, stew/byteutils, results, chronicles +import hashes, chronos, results, chronicles import libp2p/stream/connection import libp2p import ./[serialization] @@ -35,16 +35,6 @@ method readLp*( method write*( self: MixReplyConnection, msg: seq[byte] -): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true), public.} = - self.mixReplyDialer(self.surbs, msg) - -proc write*( - self: MixReplyConnection, msg: string -): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true), public.} = - self.write(msg.toBytes()) - -method writeLp*( - self: MixReplyConnection, msg: openArray[byte] ): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true), public.} = if msg.len() > dataSize: let fut = newFuture[void]() @@ -53,25 +43,38 @@ method writeLp*( ) return fut - var - vbytes: seq[byte] = @[] - value = msg.len().uint64 + self.mixReplyDialer(self.surbs, msg) - while value >= 128: - vbytes.add(byte((value and 127) or 128)) - value = value shr 7 - vbytes.add(byte(value)) +proc write*( + self: MixReplyConnection, msg: string +): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true), public.} = + let fut = newFuture[void]() + fut.fail( + newException(LPStreamError, "write(string) not implemented for MixReplyConnection") + ) + return fut - var buf = newSeqUninitialized[byte](msg.len() + vbytes.len) - buf[0 ..< vbytes.len] = vbytes.toOpenArray(0, vbytes.len - 1) - buf[vbytes.len ..< buf.len] = msg - - self.mixReplyDialer(self.surbs, @buf) +method writeLp*( + self: MixReplyConnection, msg: openArray[byte] +): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true), public.} = + let fut = newFuture[void]() + fut.fail( + newException( + LPStreamError, "writeLp(seq[byte]) not implemented for MixReplyConnection" + ) + ) + return fut method writeLp*( self: MixReplyConnection, msg: string ): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true), public.} = - self.writeLp(msg.toOpenArrayByte(0, msg.high)) + let fut = newFuture[void]() + fut.fail( + newException( + LPStreamError, "writeLp(string) not implemented for MixReplyConnection" + ) + ) + return fut proc shortLog*(self: MixReplyConnection): string {.raises: [].} = "[MixReplyConnection]" diff --git a/mix/sphinx.nim b/mix/sphinx.nim index 371171d..a9158d4 100644 --- a/mix/sphinx.nim +++ b/mix/sphinx.nim @@ -1,6 +1,6 @@ import results, sequtils import nimcrypto/sysrand -import ./[config, crypto, curve25519, serialization, tag_manager] +import ./[config, crypto, curve25519, serialization, tag_manager, utils] # Define possible outcomes of processing a Sphinx packet type ProcessingStatus* = enum @@ -123,7 +123,7 @@ proc computeBetaGamma( delay: openArray[seq[byte]], destHop: Hop, id: I, -): Result[(seq[byte], seq[byte]), string] = # TODO: name tuples +): Result[tuple[beta: seq[byte], gamma: seq[byte]], string] = let sLen = s.len var beta: seq[byte] @@ -161,7 +161,7 @@ proc computeBetaGamma( gamma = toSeq(hmac(mac_key, beta)) - return ok((beta, gamma)) + return ok((beta: beta, gamma: gamma)) # Function to compute deltas proc computeDelta(s: seq[seq[byte]], msg: Message): Result[seq[byte], string] = @@ -291,9 +291,24 @@ proc wrapInSphinxPacket*( return ok(serialized) +type ProcessedSphinxPacket* = object + case status*: ProcessingStatus + of ProcessingStatus.Exit: + destination*: Hop + messageChunk*: seq[byte] + of ProcessingStatus.Intermediate: + nextHop*: Hop + delayMs*: int + serializedSphinxPacket*: seq[byte] + of ProcessingStatus.Reply: + id*: I + delta_prime*: seq[byte] + else: + discard + proc processSphinxPacket*( sphinxPacket: SphinxPacket, privateKey: FieldElement, tm: var TagManager -): Result[(Hop, seq[byte], seq[byte], ProcessingStatus), string] = # TODO: named touple +): Result[ProcessedSphinxPacket, string] = let (header, payload) = sphinxPacket.getSphinxPacket() (alpha, beta, gamma) = getHeader(header) @@ -308,14 +323,14 @@ proc processSphinxPacket*( # Check if the tag has been seen if isTagSeen(tm, s): - return ok((Hop(), @[], @[], Duplicate)) + return ok(ProcessedSphinxPacket(status: Duplicate)) # Compute MAC let mac_key = kdf(deriveKeyMaterial("mac_key", sBytes)) if not (toSeq(hmac(mac_key, beta)) == gamma): # If MAC not verified - return ok((Hop(), @[], @[], InvalidMAC)) + return ok(ProcessedSphinxPacket(status: InvalidMAC)) # Store the tag as seen addTag(tm, s) @@ -350,14 +365,18 @@ proc processSphinxPacket*( let msg = Message.deserialize(delta_prime).valueOr: return err("Message deserialization error: " & error) let content = msg.getContent() - return - ok((hop, B[addrSize .. ((t * k) - 1)], content[0 .. messageSize - 1], Exit)) + return ok( + ProcessedSphinxPacket( + status: Exit, destination: hop, messageChunk: content[0 .. messageSize - 1] + ) + ) else: return err("delta_prime should be all zeros") elif B[0 .. (t * k) - 1] == newSeq[byte](t * k): - let I = B[(t * k) .. (((t + 1) * k) - 1)] - # TODO: return I - return ok((hop, B[addrSize .. ((t * k) - 1)], delta_prime, Reply)) + let idSeq = B[(t * k) .. (((t + 1) * k) - 1)] + var id: I + copyMem(addr id[0], unsafeAddr idSeq[0], k) + return ok(ProcessedSphinxPacket(status: Reply, id: id, delta_prime: delta_prime)) else: # Extract routing information from B let routingInfo = RoutingInfo.deserialize(B).valueOr: @@ -383,4 +402,11 @@ proc processSphinxPacket*( let serializedSP = sphinxPkt.serialize().valueOr: return err("Sphinx packet serialization error: " & error) - return ok((address, delay, serializedSP, Intermediate)) + return ok( + ProcessedSphinxPacket( + status: Intermediate, + nextHop: address, + delayMs: (?bytesToUInt16(delay)).int, + serializedSphinxPacket: serializedSP, + ) + ) diff --git a/tests/test_sphinx.nim b/tests/test_sphinx.nim index 1854ed8..6ac8a53 100644 --- a/tests/test_sphinx.nim +++ b/tests/test_sphinx.nim @@ -94,18 +94,19 @@ suite "Sphinx Tests": if res1.isErr: error "Error in Sphinx processing", err = res1.error fail() - let (address1, delay1, processedPacket1Bytes, status1) = res1.get() + let processedSP1 = res1.get() - if status1 != Intermediate: + if processedSP1.status != Intermediate: error "Processing status should be Intermediate" fail() - if processedPacket1Bytes.len != packetSize: + if processedSP1.serializedSphinxPacket.len != packetSize: error "Packet length is not valid", - pkt_len = $(processedPacket1Bytes.len), expected_len = $packetSize + pkt_len = $(processedSP1.serializedSphinxPacket.len), expected_len = $packetSize fail() - let processedPacket1Res = SphinxPacket.deserialize(processedPacket1Bytes) + let processedPacket1Res = + SphinxPacket.deserialize(processedSP1.serializedSphinxPacket) if processedPacket1Res.isErr: error "Sphinx wrap error", err = processedPacket1Res.error fail() @@ -115,18 +116,19 @@ suite "Sphinx Tests": if res2.isErr: error "Error in Sphinx processing", err = res2.error fail() - let (address2, delay2, processedPacket2Bytes, status2) = res2.get() + let processedSP2 = res2.get() - if status2 != Intermediate: + if processedSP2.status != Intermediate: error "Processing status should be Intermediate" fail() - if processedPacket2Bytes.len != packetSize: + if processedSP2.serializedSphinxPacket.len != packetSize: error "Packet length is not valid", - pkt_len = $(processedPacket2Bytes.len), expected_len = $packetSize + pkt_len = $(processedSP2.serializedSphinxPacket.len), expected_len = $packetSize fail() - let processedPacket2Res = SphinxPacket.deserialize(processedPacket2Bytes) + let processedPacket2Res = + SphinxPacket.deserialize(processedSP2.serializedSphinxPacket) if processedPacket2Res.isErr: error "Sphinx wrap error", err = processedPacket2Res.error fail() @@ -136,13 +138,13 @@ suite "Sphinx Tests": if res3.isErr: error "Error in Sphinx processing", err = res3.error fail() - let (address3, delay3, processedPacket3, status3) = res3.get() + let processedSP3 = res3.get() - if status3 != Exit: - error "Processing status should be Exit", status3 + if processedSP3.status != Exit: + error "Processing status should be Exit", status = processedSP3.status fail() - if processedPacket3 != message.getContent(): + if processedSP3.messageChunk != message.getContent(): error "Packet processing failed" fail() @@ -181,9 +183,9 @@ suite "Sphinx Tests": if res.isErr: error "Error in Sphinx processing", err = res.error fail() - let (_, _, _, status) = res.get() + let invalidMacPkt = res.get() - if status != InvalidMAC: + if invalidMacPkt.status != InvalidMAC: error "Processing status should be InvalidMAC" fail() @@ -211,9 +213,9 @@ suite "Sphinx Tests": if res1.isErr: error "Error in Sphinx processing", err = res1.error fail() - let (_, _, _, status1) = res1.get() + let processedSP1 = res1.get() - if status1 != Intermediate: + if processedSP1.status != Intermediate: error "Processing status should be Intermediate" fail() @@ -221,9 +223,9 @@ suite "Sphinx Tests": if res2.isErr: error "Error in Sphinx processing", err = res2.error fail() - let (_, _, _, status2) = res2.get() + let processedSP2 = res2.get() - if status2 != Duplicate: + if processedSP2.status != Duplicate: error "Processing status should be Duplicate" fail() @@ -260,18 +262,20 @@ suite "Sphinx Tests": if res1.isErr: error "Error in Sphinx processing", err = res1.error fail() - let (address1, delay1, processedPacket1Bytes, status1) = res1.get() + let processedSP1 = res1.get() - if status1 != Intermediate: + if processedSP1.status != Intermediate: error "Processing status should be Intermediate" fail() - if processedPacket1Bytes.len != packetSize: + if processedSP1.serializedSphinxPacket.len != packetSize: error "Packet length is not valid", - pkt_len = $(processedPacket1Bytes.len), expected_len = $packetSize + pkt_len = $(processedSP1.serializedSphinxPacket.len), + expected_len = $packetSize fail() - let processedPacket1Res = SphinxPacket.deserialize(processedPacket1Bytes) + let processedPacket1Res = + SphinxPacket.deserialize(processedSP1.serializedSphinxPacket) if processedPacket1Res.isErr: error "Sphinx wrap error", err = processedPacket1Res.error fail() @@ -281,18 +285,20 @@ suite "Sphinx Tests": if res2.isErr: error "Error in Sphinx processing", err = res2.error fail() - let (address2, delay2, processedPacket2Bytes, status2) = res2.get() + let processedSP2 = res2.get() - if status2 != Intermediate: + if processedSP2.status != Intermediate: error "Processing status should be Intermediate" fail() - if processedPacket2Bytes.len != packetSize: + if processedSP2.serializedSphinxPacket.len != packetSize: error "Packet length is not valid", - pkt_len = $(processedPacket2Bytes.len), expected_len = $packetSize + pkt_len = $(processedSP2.serializedSphinxPacket.len), + expected_len = $packetSize fail() - let processedPacket2Res = SphinxPacket.deserialize(processedPacket2Bytes) + let processedPacket2Res = + SphinxPacket.deserialize(processedSP2.serializedSphinxPacket) if processedPacket2Res.isErr: error "Sphinx wrap error", err = processedPacket2Res.error fail() @@ -302,13 +308,13 @@ suite "Sphinx Tests": if res3.isErr: error "Error in Sphinx processing", err = res3.error fail() - let (address3, delay3, processedPacket3, status3) = res3.get() + let processedSP3 = res3.get() - if status3 != Exit: + if processedSP3.status != Exit: error "Processing status should be Exit" fail() - if processedPacket3 != paddedMessage: + if processedSP3.messageChunk != paddedMessage: error "Packet processing failed" fail() @@ -340,18 +346,19 @@ suite "Sphinx Tests": if res1.isErr: error "Error in Sphinx processing", err = res1.error fail() - let (address1, delay1, processedPacket1Bytes, status1) = res1.get() + let processedSP1 = res1.get() - if status1 != Intermediate: + if processedSP1.status != Intermediate: error "Processing status should be Intermediate" fail() - if processedPacket1Bytes.len != packetSize: + if processedSP1.serializedSphinxPacket.len != packetSize: error "Packet length is not valid", - pkt_len = $(processedPacket1Bytes.len), expected_len = $packetSize + pkt_len = $(processedSP1.serializedSphinxPacket.len), expected_len = $packetSize fail() - let processedPacket1Res = SphinxPacket.deserialize(processedPacket1Bytes) + let processedPacket1Res = + SphinxPacket.deserialize(processedSP1.serializedSphinxPacket) if processedPacket1Res.isErr: error "Sphinx wrap error", err = processedPacket1Res.error fail() @@ -361,18 +368,19 @@ suite "Sphinx Tests": if res2.isErr: error "Error in Sphinx processing", err = res2.error fail() - let (address2, delay2, processedPacket2Bytes, status2) = res2.get() + let processedSP2 = res2.get() - if status2 != Intermediate: - error "Processing status should be Success" + if processedSP2.status != Intermediate: + error "Processing status should be Intermediate" fail() - if processedPacket2Bytes.len != packetSize: + if processedSP2.serializedSphinxPacket.len != packetSize: error "Packet length is not valid", - pkt_len = $(processedPacket2Bytes.len), expected_len = $packetSize + pkt_len = $(processedSP2.serializedSphinxPacket.len), expected_len = $packetSize fail() - let processedPacket2Res = SphinxPacket.deserialize(processedPacket2Bytes) + let processedPacket2Res = + SphinxPacket.deserialize(processedSP2.serializedSphinxPacket) if processedPacket2Res.isErr: error "Sphinx wrap error", err = processedPacket2Res.error fail() @@ -382,13 +390,13 @@ suite "Sphinx Tests": if res3.isErr: error "Error in Sphinx processing", err = res3.error fail() - let (address3, delay3, processedPacket3, status3) = res3.get() + let processedSP3 = res3.get() - if status3 != Reply: + if processedSP3.status != Reply: error "Processing status should be Reply" fail() - let msgRes = processReply(surb.key, surb.secret.get(), processedPacket3) + let msgRes = processReply(surb.key, surb.secret.get(), processedSP3.delta_prime) if msgRes.isErr: error "Reply processing failed", err = msgRes.error let msg = msgRes.get() @@ -437,9 +445,9 @@ suite "Sphinx Tests": if res.isErr: error "Error in Sphinx processing", err = res.error fail() - let (_, _, _, status) = res.get() + let processedSP1 = res.get() - if status != InvalidMAC: + if processedSP1.status != InvalidMAC: error "Processing status should be InvalidMAC" fail() @@ -472,19 +480,19 @@ suite "Sphinx Tests": if res1.isErr: error "Error in Sphinx processing", err = res1.error fail() - let (_, _, _, status1) = res1.get() + let processedSP1 = res1.get() - if status1 != Intermediate: - error "Processing status should be Success" + if processedSP1.status != Intermediate: + error "Processing status should be Intermediate" fail() let res2 = processSphinxPacket(packet, privateKeys[0], tm) if res2.isErr: error "Error in Sphinx processing", err = res2.error fail() - let (_, _, _, status2) = res2.get() + let processedSP2 = res2.get() - if status2 != Duplicate: + if processedSP2.status != Duplicate: error "Processing status should be Duplicate" fail() @@ -525,18 +533,20 @@ suite "Sphinx Tests": if res1.isErr: error "Error in Sphinx processing", err = res1.error fail() - let (address1, delay1, processedPacket1Bytes, status1) = res1.get() + let processedSP1 = res1.get() - if status1 != Intermediate: - error "Processing status should be Success" + if processedSP1.status != Intermediate: + error "Processing status should be Intermediate" fail() - if processedPacket1Bytes.len != packetSize: + if processedSP1.serializedSphinxPacket.len != packetSize: error "Packet length is not valid", - pkt_len = $(processedPacket1Bytes.len), expected_len = $packetSize + pkt_len = $(processedSP1.serializedSphinxPacket.len), + expected_len = $packetSize fail() - let processedPacket1Res = SphinxPacket.deserialize(processedPacket1Bytes) + let processedPacket1Res = + SphinxPacket.deserialize(processedSP1.serializedSphinxPacket) if processedPacket1Res.isErr: error "Sphinx wrap error", err = processedPacket1Res.error fail() @@ -546,18 +556,20 @@ suite "Sphinx Tests": if res2.isErr: error "Error in Sphinx processing", err = res2.error fail() - let (address2, delay2, processedPacket2Bytes, status2) = res2.get() + let processedSP2 = res2.get() - if status2 != Intermediate: - error "Processing status should be Success" + if processedSP2.status != Intermediate: + error "Processing status should be Intermediate" fail() - if processedPacket2Bytes.len != packetSize: + if processedSP2.serializedSphinxPacket.len != packetSize: error "Packet length is not valid", - pkt_len = $(processedPacket2Bytes.len), expected_len = $packetSize + pkt_len = $(processedSP2.serializedSphinxPacket.len), + expected_len = $packetSize fail() - let processedPacket2Res = SphinxPacket.deserialize(processedPacket2Bytes) + let processedPacket2Res = + SphinxPacket.deserialize(processedSP2.serializedSphinxPacket) if processedPacket1Res.isErr: error "Sphinx wrap error", err = processedPacket2Res.error fail() @@ -567,13 +579,13 @@ suite "Sphinx Tests": if res3.isErr: error "Error in Sphinx processing", err = res3.error fail() - let (address3, delay3, processedPacket3, status3) = res3.get() + let processedSP3 = res3.get() - if status3 != Reply: + if processedSP3.status != Reply: error "Processing status should be Reply" fail() - let msgRes = processReply(surb.key, surb.secret.get(), processedPacket3) + let msgRes = processReply(surb.key, surb.secret.get(), processedSP3.delta_prime) if msgRes.isErr: error "Reply processing failed", err = msgRes.error let msg = msgRes.get()