mirror of
https://github.com/vacp2p/mix.git
synced 2026-01-09 02:38:00 -05:00
refactor: processSphinxPacket and process replies and return in conn (#78)
This commit is contained in:
@@ -116,7 +116,7 @@ proc mixnetSimulation() {.async: (raises: [Exception]).} =
|
||||
|
||||
let response = await pingProto[senderIndex].ping(conn)
|
||||
|
||||
await sleepAsync(1.seconds)
|
||||
info "PING RESPONSE", response
|
||||
|
||||
deleteNodeInfoFolder()
|
||||
deletePubInfoFolder()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import hashes, chronos, stew/byteutils, results, chronicles
|
||||
import libp2p/stream/connection
|
||||
import libp2p/varint
|
||||
import ./mix_protocol
|
||||
import ./config
|
||||
from fragmentation import dataSize
|
||||
@@ -44,27 +45,134 @@ type MixEntryConnection* = ref object of Connection
|
||||
mixDialer: MixDialer
|
||||
params: Opt[MixParameters]
|
||||
|
||||
incoming: AsyncQueue[seq[byte]]
|
||||
incomingFut: Future[void]
|
||||
replyReceivedFut: Future[void]
|
||||
cached: seq[byte]
|
||||
|
||||
method readOnce*(
|
||||
s: MixEntryConnection, pbytes: pointer, nbytes: int
|
||||
): Future[int] {.async: (raises: [CancelledError, LPStreamError]), public.} =
|
||||
if s.isEof:
|
||||
raise newLPStreamEOFError()
|
||||
|
||||
try:
|
||||
await s.replyReceivedFut
|
||||
s.isEof = true
|
||||
if s.cached.len == 0:
|
||||
raise newLPStreamEOFError()
|
||||
except CancelledError as exc:
|
||||
raise exc
|
||||
except LPStreamEOFError as exc:
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
raise (ref LPStreamError)(msg: "error in readOnce: " & exc.msg, parent: exc)
|
||||
|
||||
let toRead = min(nbytes, s.cached.len)
|
||||
copyMem(pbytes, addr s.cached[0], toRead)
|
||||
s.cached = s.cached[toRead ..^ 1]
|
||||
return toRead
|
||||
|
||||
method readExactly*(
|
||||
self: MixEntryConnection, pbytes: pointer, nbytes: int
|
||||
s: MixEntryConnection, pbytes: pointer, nbytes: int
|
||||
): Future[void] {.async: (raises: [CancelledError, LPStreamError]), public.} =
|
||||
await sleepAsync(10.minutes) # TODO: implement readExactly
|
||||
raise
|
||||
newException(LPStreamError, "readExactly not implemented for MixEntryConnection")
|
||||
## Waits for `nbytes` to be available, then read
|
||||
## them and return them
|
||||
if s.atEof:
|
||||
var ch: char
|
||||
discard await s.readOnce(addr ch, 1)
|
||||
raise newLPStreamEOFError()
|
||||
|
||||
if nbytes == 0:
|
||||
return
|
||||
|
||||
logScope:
|
||||
s
|
||||
nbytes = nbytes
|
||||
objName = s.objName
|
||||
|
||||
var pbuffer = cast[ptr UncheckedArray[byte]](pbytes)
|
||||
var read = 0
|
||||
while read < nbytes and not (s.atEof()):
|
||||
read += await s.readOnce(addr pbuffer[read], nbytes - read)
|
||||
|
||||
if read == 0:
|
||||
doAssert s.atEof()
|
||||
trace "couldn't read all bytes, stream EOF", s, nbytes, read
|
||||
# Re-readOnce to raise a more specific error than EOF
|
||||
# Raise EOF if it doesn't raise anything(shouldn't happen)
|
||||
discard await s.readOnce(addr pbuffer[read], nbytes - read)
|
||||
|
||||
raise newLPStreamEOFError()
|
||||
|
||||
if read < nbytes:
|
||||
trace "couldn't read all bytes, incomplete data", s, nbytes, read
|
||||
raise newLPStreamIncompleteError()
|
||||
|
||||
method readLine*(
|
||||
self: MixEntryConnection, limit = 0, sep = "\r\n"
|
||||
s: MixEntryConnection, limit = 0, sep = "\r\n"
|
||||
): Future[string] {.async: (raises: [CancelledError, LPStreamError]), public.} =
|
||||
raise newException(LPStreamError, "readLine not implemented for MixEntryConnection")
|
||||
## Reads up to `limit` bytes are read, or a `sep` is found
|
||||
# TODO replace with something that exploits buffering better
|
||||
var lim = if limit <= 0: -1 else: limit
|
||||
var state = 0
|
||||
|
||||
while true:
|
||||
var ch: char
|
||||
await readExactly(s, addr ch, 1)
|
||||
|
||||
if sep[state] == ch:
|
||||
inc(state)
|
||||
if state == len(sep):
|
||||
break
|
||||
else:
|
||||
state = 0
|
||||
if limit > 0:
|
||||
let missing = min(state, lim - len(result) - 1)
|
||||
result.add(sep[0 ..< missing])
|
||||
else:
|
||||
result.add(sep[0 ..< state])
|
||||
|
||||
result.add(ch)
|
||||
if len(result) == lim:
|
||||
break
|
||||
|
||||
method readVarint*(
|
||||
self: MixEntryConnection
|
||||
conn: MixEntryConnection
|
||||
): Future[uint64] {.async: (raises: [CancelledError, LPStreamError]), public.} =
|
||||
raise newException(LPStreamError, "readVarint not implemented for MixEntryConnection")
|
||||
var buffer: array[10, byte]
|
||||
|
||||
for i in 0 ..< len(buffer):
|
||||
await conn.readExactly(addr buffer[i], 1)
|
||||
|
||||
var
|
||||
varint: uint64
|
||||
length: int
|
||||
let res = PB.getUVarint(buffer.toOpenArray(0, i), length, varint)
|
||||
if res.isOk():
|
||||
return varint
|
||||
if res.error() != VarintError.Incomplete:
|
||||
break
|
||||
if true: # can't end with a raise apparently
|
||||
raise (ref InvalidVarintError)(msg: "Cannot parse varint")
|
||||
|
||||
method readLp*(
|
||||
self: MixEntryConnection, maxSize: int
|
||||
s: MixEntryConnection, maxSize: int
|
||||
): Future[seq[byte]] {.async: (raises: [CancelledError, LPStreamError]), public.} =
|
||||
raise newException(LPStreamError, "readLp not implemented for MixEntryConnection")
|
||||
## read length prefixed msg, with the length encoded as a varint
|
||||
let
|
||||
length = await s.readVarint()
|
||||
maxLen = uint64(if maxSize < 0: int.high else: maxSize)
|
||||
|
||||
if length > maxLen:
|
||||
raise (ref MaxSizeError)(msg: "Message exceeds maximum length")
|
||||
|
||||
if length == 0:
|
||||
return
|
||||
|
||||
var res = newSeqUninitialized[byte](length)
|
||||
await s.readExactly(addr res[0], res.len)
|
||||
res
|
||||
|
||||
method write*(
|
||||
self: MixEntryConnection, msg: seq[byte]
|
||||
@@ -112,6 +220,7 @@ proc shortLog*(self: MixEntryConnection): string {.raises: [].} =
|
||||
method closeImpl*(
|
||||
self: MixEntryConnection
|
||||
): Future[void] {.async: (raises: [], raw: true).} =
|
||||
self.incomingFut.cancelSoon()
|
||||
let fut = newFuture[void]()
|
||||
fut.complete()
|
||||
return fut
|
||||
@@ -123,22 +232,6 @@ when defined(libp2p_agents_metrics):
|
||||
proc setShortAgent*(self: MixEntryConnection, shortAgent: string) =
|
||||
discard
|
||||
|
||||
proc new*(
|
||||
T: typedesc[MixEntryConnection],
|
||||
srcMix: MixProtocol,
|
||||
destination: Destination,
|
||||
codec: string,
|
||||
mixDialer: MixDialer,
|
||||
params: Opt[MixParameters],
|
||||
): T =
|
||||
let instance =
|
||||
T(destination: destination, codec: codec, mixDialer: mixDialer, params: params)
|
||||
|
||||
when defined(libp2p_agents_metrics):
|
||||
instance.shortAgent = connection.shortAgent
|
||||
|
||||
instance
|
||||
|
||||
proc new*(
|
||||
T: typedesc[MixEntryConnection],
|
||||
srcMix: MixProtocol,
|
||||
@@ -154,7 +247,20 @@ proc new*(
|
||||
else:
|
||||
0
|
||||
|
||||
var sendDialerFunc = proc(
|
||||
var instance = T()
|
||||
instance.destination = destination
|
||||
instance.codec = codec
|
||||
instance.params = Opt.some(params)
|
||||
|
||||
if expectReply:
|
||||
instance.incoming = newAsyncQueue[seq[byte]]()
|
||||
instance.replyReceivedFut = newFuture[void]()
|
||||
let checkForIncoming = proc(): Future[void] {.async: (raises: [CancelledError]).} =
|
||||
instance.cached = await instance.incoming.get()
|
||||
instance.replyReceivedFut.complete()
|
||||
instance.incomingFut = checkForIncoming()
|
||||
|
||||
instance.mixDialer = proc(
|
||||
msg: seq[byte], codec: string, dest: Destination
|
||||
): Future[void] {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
try:
|
||||
@@ -164,12 +270,17 @@ proc new*(
|
||||
else:
|
||||
(Opt.none(PeerId), Opt.some(MixDestination.init(dest.peerId, dest.address)))
|
||||
|
||||
await srcMix.anonymizeLocalProtocolSend(msg, codec, peerId, destination, surbs)
|
||||
await srcMix.anonymizeLocalProtocolSend(
|
||||
instance.incoming, msg, codec, peerId, destination, surbs
|
||||
)
|
||||
except CatchableError as e:
|
||||
error "Error during execution of anonymizeLocalProtocolSend: ", err = e.msg
|
||||
return
|
||||
|
||||
T.new(srcMix, destination, codec, sendDialerFunc, Opt.some(params))
|
||||
when defined(libp2p_agents_metrics):
|
||||
instance.shortAgent = connection.shortAgent
|
||||
|
||||
instance
|
||||
|
||||
proc toConnection*(
|
||||
srcMix: MixProtocol,
|
||||
|
||||
@@ -79,7 +79,7 @@ proc reply(
|
||||
if not replyConn.isNil:
|
||||
await replyConn.close()
|
||||
try:
|
||||
await replyConn.writeLp(response)
|
||||
await replyConn.write(response)
|
||||
except LPStreamError as exc:
|
||||
error "could not reply", description = exc.msg
|
||||
mix_messages_error.inc(labelValues = ["ExitLayer", "REPLY_FAILED"])
|
||||
|
||||
@@ -12,6 +12,10 @@ import
|
||||
|
||||
const MixProtocolID* = "/mix/1.0.0"
|
||||
|
||||
type ConnCreds = object
|
||||
incoming: AsyncQueue[seq[byte]]
|
||||
surbSKSeq: seq[(secret, key)]
|
||||
|
||||
type MixProtocol* = ref object of LPProtocol
|
||||
mixNodeInfo: MixNodeInfo
|
||||
pubNodeInfo: Table[PeerId, MixPubInfo]
|
||||
@@ -19,8 +23,8 @@ type MixProtocol* = ref object of LPProtocol
|
||||
tagManager: TagManager
|
||||
exitLayer: ExitLayer
|
||||
rng: ref HmacDrbgContext
|
||||
# TODO: might require cleanup?
|
||||
idToSKey: Table[array[surbIdLen, byte], seq[(secret, key)]]
|
||||
# TODO: verify if this requires cleanup for cases in which response never arrives (and connection is closed)
|
||||
connCreds: Table[I, ConnCreds]
|
||||
fwdRBehavior: TableRef[string, fwdReadBehaviorCb]
|
||||
|
||||
proc hasFwdBehavior*(mixProto: MixProtocol, codec: string): bool =
|
||||
@@ -91,55 +95,91 @@ proc handleMixNodeConnection(
|
||||
mix_messages_error.inc(labelValues = ["Intermediate/Exit", "INVALID_SPHINX"])
|
||||
return
|
||||
|
||||
let (nextHop, delay, processedPkt, status) = processSphinxPacket(
|
||||
sphinxPacket, mixPrivKey, mixProto.tagManager
|
||||
).valueOr:
|
||||
let processedSP = processSphinxPacket(sphinxPacket, mixPrivKey, mixProto.tagManager).valueOr:
|
||||
error "Failed to process Sphinx packet", err = error
|
||||
mix_messages_error.inc(labelValues = ["Intermediate/Exit", "INVALID_SPHINX"])
|
||||
return
|
||||
|
||||
case status
|
||||
case processedSP.status
|
||||
of Exit:
|
||||
mix_messages_recvd.inc(labelValues = [$status])
|
||||
mix_messages_recvd.inc(labelValues = [$processedSP.status])
|
||||
# This is the exit node, forward to destination
|
||||
let msgChunk = MessageChunk.deserialize(processedPkt).valueOr:
|
||||
let msgChunk = MessageChunk.deserialize(processedSP.messageChunk).valueOr:
|
||||
error "Deserialization failed", err = error
|
||||
mix_messages_error.inc(labelValues = [$status, "INVALID_SPHINX"])
|
||||
mix_messages_error.inc(labelValues = ["Exit", "INVALID_SPHINX"])
|
||||
return
|
||||
|
||||
let unpaddedMsg = unpadMessage(msgChunk).valueOr:
|
||||
error "Unpadding message failed", err = error
|
||||
mix_messages_error.inc(labelValues = [$status, "INVALID_SPHINX"])
|
||||
mix_messages_error.inc(labelValues = ["Exit", "INVALID_SPHINX"])
|
||||
return
|
||||
|
||||
let deserialized = MixMessage.deserialize(unpaddedMsg).valueOr:
|
||||
error "Deserialization failed", err = error
|
||||
mix_messages_error.inc(labelValues = [$status, "INVALID_SPHINX"])
|
||||
mix_messages_error.inc(labelValues = ["Exit", "INVALID_SPHINX"])
|
||||
return
|
||||
|
||||
let (surbs, message) = extractSURBs(deserialized.message).valueOr:
|
||||
error "Extracting surbs from payload failed", err = error
|
||||
mix_messages_error.inc(labelValues = [$status, "INVALID_MSG_SURBS"])
|
||||
mix_messages_error.inc(labelValues = ["Exit", "INVALID_MSG_SURBS"])
|
||||
return
|
||||
|
||||
trace "Exit node - Received mix message",
|
||||
receiver = multiAddr, message = deserialized.message, codec = deserialized.codec
|
||||
|
||||
await mixProto.exitLayer.onMessage(deserialized.codec, message, nextHop, surbs)
|
||||
await mixProto.exitLayer.onMessage(
|
||||
deserialized.codec, message, processedSP.destination, surbs
|
||||
)
|
||||
|
||||
mix_messages_forwarded.inc(labelValues = [$status])
|
||||
mix_messages_forwarded.inc(labelValues = ["Exit"])
|
||||
of Reply:
|
||||
error "TODO: IMPLEMENT REPLY STATE"
|
||||
# TODO: process reply at entry side
|
||||
trace "# Reply", id = processedSP.id
|
||||
try:
|
||||
if not mixProto.connCreds.hasKey(processedSP.id):
|
||||
mix_messages_error.inc(labelValues = ["Sender/Reply", "NO_CONN_FOUND"])
|
||||
return
|
||||
|
||||
let connCred = mixProto.connCreds[processedSP.id]
|
||||
mixProto.connCreds.del(processedSP.id)
|
||||
|
||||
var couldProcessReply = false
|
||||
var reply: seq[byte]
|
||||
for sk in connCred.surbSKSeq:
|
||||
let processReplyRes = processReply(sk[1], sk[0], processedSP.delta_prime)
|
||||
if processReplyRes.isOk:
|
||||
couldProcessReply = true
|
||||
reply = processReplyRes.value()
|
||||
break
|
||||
|
||||
if couldProcessReply:
|
||||
let msgChunk = MessageChunk.deserialize(reply).valueOr:
|
||||
error "Deserialization failed", err = error
|
||||
mix_messages_error.inc(labelValues = ["Reply", "INVALID_SPHINX"])
|
||||
return
|
||||
|
||||
let unpaddedMsg = unpadMessage(msgChunk).valueOr:
|
||||
error "Unpadding message failed", err = error
|
||||
mix_messages_error.inc(labelValues = ["Reply", "INVALID_SPHINX"])
|
||||
return
|
||||
|
||||
let deserialized = MixMessage.deserialize(unpaddedMsg).valueOr:
|
||||
error "Deserialization failed", err = error
|
||||
mix_messages_error.inc(labelValues = ["Reply", "INVALID_SPHINX"])
|
||||
return
|
||||
|
||||
await connCred.incoming.put(deserialized.message)
|
||||
else:
|
||||
error "could not process reply", id = processedSP.id
|
||||
except KeyError as ex:
|
||||
doAssert false, "checked with hasKey"
|
||||
of Intermediate:
|
||||
trace "# Intermediate: ", multiAddr = multiAddr
|
||||
# Add delay
|
||||
let delayMillis = (delay[0].int shl 8) or delay[1].int
|
||||
mix_messages_recvd.inc(labelValues = ["Intermediate"])
|
||||
await sleepAsync(milliseconds(delayMillis))
|
||||
await sleepAsync(milliseconds(processedSP.delayMs))
|
||||
|
||||
# Forward to next hop
|
||||
let nextHopBytes = getHop(nextHop)
|
||||
let nextHopBytes = getHop(processedSP.nextHop)
|
||||
|
||||
let fullAddrStr = bytesToMultiAddr(nextHopBytes).valueOr:
|
||||
error "Failed to convert bytes to multiaddress", err = error
|
||||
@@ -170,7 +210,7 @@ proc handleMixNodeConnection(
|
||||
var nextHopConn: Connection
|
||||
try:
|
||||
nextHopConn = await mixProto.switch.dial(peerId, @[locationAddr], MixProtocolID)
|
||||
await nextHopConn.writeLp(processedPkt)
|
||||
await nextHopConn.writeLp(processedSP.serializedSphinxPacket)
|
||||
mix_messages_forwarded.inc(labelValues = ["Intermediate"])
|
||||
except CatchableError as e:
|
||||
error "Failed to dial next hop: ", err = e.msg
|
||||
@@ -200,7 +240,10 @@ proc getMaxMessageSizeForCodec*(
|
||||
return ok(dataSize - totalLen)
|
||||
|
||||
proc buildSurbs(
|
||||
mixProto: MixProtocol, numSurbs: uint8, skipPeer: PeerId
|
||||
mixProto: MixProtocol,
|
||||
incoming: AsyncQueue[seq[byte]],
|
||||
numSurbs: uint8,
|
||||
exitPeerId: PeerId,
|
||||
): Result[seq[SURB], string] =
|
||||
var response: seq[SURB]
|
||||
var surbSK: seq[(secret, key)] = @[]
|
||||
@@ -225,6 +268,13 @@ proc buildSurbs(
|
||||
randPeerId: PeerId
|
||||
availableIndices = toSeq(0 ..< numMixNodes)
|
||||
|
||||
# Remove exit node from nodes to consider for surbs
|
||||
let index = pubNodeInfoKeys.find(exitPeerId)
|
||||
if index != -1:
|
||||
availableIndices.del(index)
|
||||
else:
|
||||
return err("could not find exit node")
|
||||
|
||||
var i = 0
|
||||
while i < L:
|
||||
let (multiAddr, mixPubKey, delayMillisec) =
|
||||
@@ -233,9 +283,6 @@ proc buildSurbs(
|
||||
return err("failed to generate random num: " & error)
|
||||
let selectedIndex = availableIndices[randomIndexPosition]
|
||||
randPeerId = pubNodeInfoKeys[selectedIndex]
|
||||
if randPeerId == skipPeer:
|
||||
continue
|
||||
|
||||
availableIndices.del(randomIndexPosition)
|
||||
debug "Selected mix node for surbs: ", indexInPath = i, peerId = randPeerId
|
||||
let mixPubInfo = getMixPubInfo(mixProto.pubNodeInfo.getOrDefault(randPeerId))
|
||||
@@ -269,14 +316,18 @@ proc buildSurbs(
|
||||
response.add(surb)
|
||||
|
||||
if surbSK.len != 0:
|
||||
mixProto.idToSKey[id] = surbSK
|
||||
mixProto.connCreds[id] = ConnCreds(surbSKSeq: surbSK, incoming: incoming)
|
||||
|
||||
return ok(response)
|
||||
|
||||
proc prepareMsgWithSurbs(
|
||||
mixProto: MixProtocol, msg: seq[byte], numSurbs: uint8 = 0, skipPeer: PeerId
|
||||
mixProto: MixProtocol,
|
||||
incoming: AsyncQueue[seq[byte]],
|
||||
msg: seq[byte],
|
||||
numSurbs: uint8 = 0,
|
||||
exitPeerId: PeerId,
|
||||
): Result[seq[byte], string] =
|
||||
let surbs = buildSurbs(mixProto, numSurbs, skipPeer).valueOr:
|
||||
let surbs = mixProto.buildSurbs(incoming, numSurbs, exitPeerId).valueOr:
|
||||
return err(error)
|
||||
|
||||
let serialized = ?serializeMessageWithSURBs(msg, surbs)
|
||||
@@ -352,6 +403,7 @@ proc `$`*(d: MixDestination): string =
|
||||
|
||||
proc anonymizeLocalProtocolSend*(
|
||||
mixProto: MixProtocol,
|
||||
incoming: AsyncQueue[seq[byte]],
|
||||
msg: seq[byte],
|
||||
codec: string,
|
||||
destPeerId: Opt[PeerId],
|
||||
@@ -374,7 +426,7 @@ proc anonymizeLocalProtocolSend*(
|
||||
publicKeys: seq[FieldElement] = @[]
|
||||
hop: seq[Hop] = @[]
|
||||
delay: seq[seq[byte]] = @[]
|
||||
exitNode: PeerId
|
||||
exitPeerId: PeerId
|
||||
|
||||
# Select L mix nodes at random
|
||||
let numMixNodes = mixProto.pubNodeInfo.len
|
||||
@@ -410,7 +462,7 @@ proc anonymizeLocalProtocolSend*(
|
||||
while i < L:
|
||||
if destPeerId.isSome and i == L - 1:
|
||||
randPeerId = destPeerId.value()
|
||||
exitNode = destPeerId.value()
|
||||
exitPeerId = destPeerId.value()
|
||||
else:
|
||||
let randomIndexPosition = cryptoRandomInt(availableIndices.len).valueOr:
|
||||
error "Failed to genanrate random number", err = error
|
||||
@@ -426,7 +478,7 @@ proc anonymizeLocalProtocolSend*(
|
||||
continue
|
||||
# Last hop will be the exit node that will forward the request
|
||||
if i == L - 1:
|
||||
exitNode = randPeerId
|
||||
exitPeerId = randPeerId
|
||||
|
||||
debug "Selected mix node: ", indexInPath = i, peerId = randPeerId
|
||||
|
||||
@@ -469,7 +521,7 @@ proc anonymizeLocalProtocolSend*(
|
||||
else:
|
||||
Hop()
|
||||
|
||||
let msgWithSurbs = prepareMsgWithSurbs(mixProto, msg, numSurbs, exitNode).valueOr:
|
||||
let msgWithSurbs = mixProto.prepareMsgWithSurbs(incoming, msg, numSurbs, exitPeerId).valueOr:
|
||||
error "Could not prepend SURBs", err = error
|
||||
return
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import hashes, chronos, stew/byteutils, results, chronicles
|
||||
import hashes, chronos, results, chronicles
|
||||
import libp2p/stream/connection
|
||||
import libp2p
|
||||
import ./[serialization]
|
||||
@@ -35,16 +35,6 @@ method readLp*(
|
||||
|
||||
method write*(
|
||||
self: MixReplyConnection, msg: seq[byte]
|
||||
): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true), public.} =
|
||||
self.mixReplyDialer(self.surbs, msg)
|
||||
|
||||
proc write*(
|
||||
self: MixReplyConnection, msg: string
|
||||
): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true), public.} =
|
||||
self.write(msg.toBytes())
|
||||
|
||||
method writeLp*(
|
||||
self: MixReplyConnection, msg: openArray[byte]
|
||||
): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true), public.} =
|
||||
if msg.len() > dataSize:
|
||||
let fut = newFuture[void]()
|
||||
@@ -53,25 +43,38 @@ method writeLp*(
|
||||
)
|
||||
return fut
|
||||
|
||||
var
|
||||
vbytes: seq[byte] = @[]
|
||||
value = msg.len().uint64
|
||||
self.mixReplyDialer(self.surbs, msg)
|
||||
|
||||
while value >= 128:
|
||||
vbytes.add(byte((value and 127) or 128))
|
||||
value = value shr 7
|
||||
vbytes.add(byte(value))
|
||||
proc write*(
|
||||
self: MixReplyConnection, msg: string
|
||||
): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true), public.} =
|
||||
let fut = newFuture[void]()
|
||||
fut.fail(
|
||||
newException(LPStreamError, "write(string) not implemented for MixReplyConnection")
|
||||
)
|
||||
return fut
|
||||
|
||||
var buf = newSeqUninitialized[byte](msg.len() + vbytes.len)
|
||||
buf[0 ..< vbytes.len] = vbytes.toOpenArray(0, vbytes.len - 1)
|
||||
buf[vbytes.len ..< buf.len] = msg
|
||||
|
||||
self.mixReplyDialer(self.surbs, @buf)
|
||||
method writeLp*(
|
||||
self: MixReplyConnection, msg: openArray[byte]
|
||||
): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true), public.} =
|
||||
let fut = newFuture[void]()
|
||||
fut.fail(
|
||||
newException(
|
||||
LPStreamError, "writeLp(seq[byte]) not implemented for MixReplyConnection"
|
||||
)
|
||||
)
|
||||
return fut
|
||||
|
||||
method writeLp*(
|
||||
self: MixReplyConnection, msg: string
|
||||
): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true), public.} =
|
||||
self.writeLp(msg.toOpenArrayByte(0, msg.high))
|
||||
let fut = newFuture[void]()
|
||||
fut.fail(
|
||||
newException(
|
||||
LPStreamError, "writeLp(string) not implemented for MixReplyConnection"
|
||||
)
|
||||
)
|
||||
return fut
|
||||
|
||||
proc shortLog*(self: MixReplyConnection): string {.raises: [].} =
|
||||
"[MixReplyConnection]"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import results, sequtils
|
||||
import nimcrypto/sysrand
|
||||
import ./[config, crypto, curve25519, serialization, tag_manager]
|
||||
import ./[config, crypto, curve25519, serialization, tag_manager, utils]
|
||||
|
||||
# Define possible outcomes of processing a Sphinx packet
|
||||
type ProcessingStatus* = enum
|
||||
@@ -123,7 +123,7 @@ proc computeBetaGamma(
|
||||
delay: openArray[seq[byte]],
|
||||
destHop: Hop,
|
||||
id: I,
|
||||
): Result[(seq[byte], seq[byte]), string] = # TODO: name tuples
|
||||
): Result[tuple[beta: seq[byte], gamma: seq[byte]], string] =
|
||||
let sLen = s.len
|
||||
var
|
||||
beta: seq[byte]
|
||||
@@ -161,7 +161,7 @@ proc computeBetaGamma(
|
||||
|
||||
gamma = toSeq(hmac(mac_key, beta))
|
||||
|
||||
return ok((beta, gamma))
|
||||
return ok((beta: beta, gamma: gamma))
|
||||
|
||||
# Function to compute deltas
|
||||
proc computeDelta(s: seq[seq[byte]], msg: Message): Result[seq[byte], string] =
|
||||
@@ -291,9 +291,24 @@ proc wrapInSphinxPacket*(
|
||||
|
||||
return ok(serialized)
|
||||
|
||||
type ProcessedSphinxPacket* = object
|
||||
case status*: ProcessingStatus
|
||||
of ProcessingStatus.Exit:
|
||||
destination*: Hop
|
||||
messageChunk*: seq[byte]
|
||||
of ProcessingStatus.Intermediate:
|
||||
nextHop*: Hop
|
||||
delayMs*: int
|
||||
serializedSphinxPacket*: seq[byte]
|
||||
of ProcessingStatus.Reply:
|
||||
id*: I
|
||||
delta_prime*: seq[byte]
|
||||
else:
|
||||
discard
|
||||
|
||||
proc processSphinxPacket*(
|
||||
sphinxPacket: SphinxPacket, privateKey: FieldElement, tm: var TagManager
|
||||
): Result[(Hop, seq[byte], seq[byte], ProcessingStatus), string] = # TODO: named touple
|
||||
): Result[ProcessedSphinxPacket, string] =
|
||||
let
|
||||
(header, payload) = sphinxPacket.getSphinxPacket()
|
||||
(alpha, beta, gamma) = getHeader(header)
|
||||
@@ -308,14 +323,14 @@ proc processSphinxPacket*(
|
||||
|
||||
# Check if the tag has been seen
|
||||
if isTagSeen(tm, s):
|
||||
return ok((Hop(), @[], @[], Duplicate))
|
||||
return ok(ProcessedSphinxPacket(status: Duplicate))
|
||||
|
||||
# Compute MAC
|
||||
let mac_key = kdf(deriveKeyMaterial("mac_key", sBytes))
|
||||
|
||||
if not (toSeq(hmac(mac_key, beta)) == gamma):
|
||||
# If MAC not verified
|
||||
return ok((Hop(), @[], @[], InvalidMAC))
|
||||
return ok(ProcessedSphinxPacket(status: InvalidMAC))
|
||||
|
||||
# Store the tag as seen
|
||||
addTag(tm, s)
|
||||
@@ -350,14 +365,18 @@ proc processSphinxPacket*(
|
||||
let msg = Message.deserialize(delta_prime).valueOr:
|
||||
return err("Message deserialization error: " & error)
|
||||
let content = msg.getContent()
|
||||
return
|
||||
ok((hop, B[addrSize .. ((t * k) - 1)], content[0 .. messageSize - 1], Exit))
|
||||
return ok(
|
||||
ProcessedSphinxPacket(
|
||||
status: Exit, destination: hop, messageChunk: content[0 .. messageSize - 1]
|
||||
)
|
||||
)
|
||||
else:
|
||||
return err("delta_prime should be all zeros")
|
||||
elif B[0 .. (t * k) - 1] == newSeq[byte](t * k):
|
||||
let I = B[(t * k) .. (((t + 1) * k) - 1)]
|
||||
# TODO: return I
|
||||
return ok((hop, B[addrSize .. ((t * k) - 1)], delta_prime, Reply))
|
||||
let idSeq = B[(t * k) .. (((t + 1) * k) - 1)]
|
||||
var id: I
|
||||
copyMem(addr id[0], unsafeAddr idSeq[0], k)
|
||||
return ok(ProcessedSphinxPacket(status: Reply, id: id, delta_prime: delta_prime))
|
||||
else:
|
||||
# Extract routing information from B
|
||||
let routingInfo = RoutingInfo.deserialize(B).valueOr:
|
||||
@@ -383,4 +402,11 @@ proc processSphinxPacket*(
|
||||
let serializedSP = sphinxPkt.serialize().valueOr:
|
||||
return err("Sphinx packet serialization error: " & error)
|
||||
|
||||
return ok((address, delay, serializedSP, Intermediate))
|
||||
return ok(
|
||||
ProcessedSphinxPacket(
|
||||
status: Intermediate,
|
||||
nextHop: address,
|
||||
delayMs: (?bytesToUInt16(delay)).int,
|
||||
serializedSphinxPacket: serializedSP,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -94,18 +94,19 @@ suite "Sphinx Tests":
|
||||
if res1.isErr:
|
||||
error "Error in Sphinx processing", err = res1.error
|
||||
fail()
|
||||
let (address1, delay1, processedPacket1Bytes, status1) = res1.get()
|
||||
let processedSP1 = res1.get()
|
||||
|
||||
if status1 != Intermediate:
|
||||
if processedSP1.status != Intermediate:
|
||||
error "Processing status should be Intermediate"
|
||||
fail()
|
||||
|
||||
if processedPacket1Bytes.len != packetSize:
|
||||
if processedSP1.serializedSphinxPacket.len != packetSize:
|
||||
error "Packet length is not valid",
|
||||
pkt_len = $(processedPacket1Bytes.len), expected_len = $packetSize
|
||||
pkt_len = $(processedSP1.serializedSphinxPacket.len), expected_len = $packetSize
|
||||
fail()
|
||||
|
||||
let processedPacket1Res = SphinxPacket.deserialize(processedPacket1Bytes)
|
||||
let processedPacket1Res =
|
||||
SphinxPacket.deserialize(processedSP1.serializedSphinxPacket)
|
||||
if processedPacket1Res.isErr:
|
||||
error "Sphinx wrap error", err = processedPacket1Res.error
|
||||
fail()
|
||||
@@ -115,18 +116,19 @@ suite "Sphinx Tests":
|
||||
if res2.isErr:
|
||||
error "Error in Sphinx processing", err = res2.error
|
||||
fail()
|
||||
let (address2, delay2, processedPacket2Bytes, status2) = res2.get()
|
||||
let processedSP2 = res2.get()
|
||||
|
||||
if status2 != Intermediate:
|
||||
if processedSP2.status != Intermediate:
|
||||
error "Processing status should be Intermediate"
|
||||
fail()
|
||||
|
||||
if processedPacket2Bytes.len != packetSize:
|
||||
if processedSP2.serializedSphinxPacket.len != packetSize:
|
||||
error "Packet length is not valid",
|
||||
pkt_len = $(processedPacket2Bytes.len), expected_len = $packetSize
|
||||
pkt_len = $(processedSP2.serializedSphinxPacket.len), expected_len = $packetSize
|
||||
fail()
|
||||
|
||||
let processedPacket2Res = SphinxPacket.deserialize(processedPacket2Bytes)
|
||||
let processedPacket2Res =
|
||||
SphinxPacket.deserialize(processedSP2.serializedSphinxPacket)
|
||||
if processedPacket2Res.isErr:
|
||||
error "Sphinx wrap error", err = processedPacket2Res.error
|
||||
fail()
|
||||
@@ -136,13 +138,13 @@ suite "Sphinx Tests":
|
||||
if res3.isErr:
|
||||
error "Error in Sphinx processing", err = res3.error
|
||||
fail()
|
||||
let (address3, delay3, processedPacket3, status3) = res3.get()
|
||||
let processedSP3 = res3.get()
|
||||
|
||||
if status3 != Exit:
|
||||
error "Processing status should be Exit", status3
|
||||
if processedSP3.status != Exit:
|
||||
error "Processing status should be Exit", status = processedSP3.status
|
||||
fail()
|
||||
|
||||
if processedPacket3 != message.getContent():
|
||||
if processedSP3.messageChunk != message.getContent():
|
||||
error "Packet processing failed"
|
||||
fail()
|
||||
|
||||
@@ -181,9 +183,9 @@ suite "Sphinx Tests":
|
||||
if res.isErr:
|
||||
error "Error in Sphinx processing", err = res.error
|
||||
fail()
|
||||
let (_, _, _, status) = res.get()
|
||||
let invalidMacPkt = res.get()
|
||||
|
||||
if status != InvalidMAC:
|
||||
if invalidMacPkt.status != InvalidMAC:
|
||||
error "Processing status should be InvalidMAC"
|
||||
fail()
|
||||
|
||||
@@ -211,9 +213,9 @@ suite "Sphinx Tests":
|
||||
if res1.isErr:
|
||||
error "Error in Sphinx processing", err = res1.error
|
||||
fail()
|
||||
let (_, _, _, status1) = res1.get()
|
||||
let processedSP1 = res1.get()
|
||||
|
||||
if status1 != Intermediate:
|
||||
if processedSP1.status != Intermediate:
|
||||
error "Processing status should be Intermediate"
|
||||
fail()
|
||||
|
||||
@@ -221,9 +223,9 @@ suite "Sphinx Tests":
|
||||
if res2.isErr:
|
||||
error "Error in Sphinx processing", err = res2.error
|
||||
fail()
|
||||
let (_, _, _, status2) = res2.get()
|
||||
let processedSP2 = res2.get()
|
||||
|
||||
if status2 != Duplicate:
|
||||
if processedSP2.status != Duplicate:
|
||||
error "Processing status should be Duplicate"
|
||||
fail()
|
||||
|
||||
@@ -260,18 +262,20 @@ suite "Sphinx Tests":
|
||||
if res1.isErr:
|
||||
error "Error in Sphinx processing", err = res1.error
|
||||
fail()
|
||||
let (address1, delay1, processedPacket1Bytes, status1) = res1.get()
|
||||
let processedSP1 = res1.get()
|
||||
|
||||
if status1 != Intermediate:
|
||||
if processedSP1.status != Intermediate:
|
||||
error "Processing status should be Intermediate"
|
||||
fail()
|
||||
|
||||
if processedPacket1Bytes.len != packetSize:
|
||||
if processedSP1.serializedSphinxPacket.len != packetSize:
|
||||
error "Packet length is not valid",
|
||||
pkt_len = $(processedPacket1Bytes.len), expected_len = $packetSize
|
||||
pkt_len = $(processedSP1.serializedSphinxPacket.len),
|
||||
expected_len = $packetSize
|
||||
fail()
|
||||
|
||||
let processedPacket1Res = SphinxPacket.deserialize(processedPacket1Bytes)
|
||||
let processedPacket1Res =
|
||||
SphinxPacket.deserialize(processedSP1.serializedSphinxPacket)
|
||||
if processedPacket1Res.isErr:
|
||||
error "Sphinx wrap error", err = processedPacket1Res.error
|
||||
fail()
|
||||
@@ -281,18 +285,20 @@ suite "Sphinx Tests":
|
||||
if res2.isErr:
|
||||
error "Error in Sphinx processing", err = res2.error
|
||||
fail()
|
||||
let (address2, delay2, processedPacket2Bytes, status2) = res2.get()
|
||||
let processedSP2 = res2.get()
|
||||
|
||||
if status2 != Intermediate:
|
||||
if processedSP2.status != Intermediate:
|
||||
error "Processing status should be Intermediate"
|
||||
fail()
|
||||
|
||||
if processedPacket2Bytes.len != packetSize:
|
||||
if processedSP2.serializedSphinxPacket.len != packetSize:
|
||||
error "Packet length is not valid",
|
||||
pkt_len = $(processedPacket2Bytes.len), expected_len = $packetSize
|
||||
pkt_len = $(processedSP2.serializedSphinxPacket.len),
|
||||
expected_len = $packetSize
|
||||
fail()
|
||||
|
||||
let processedPacket2Res = SphinxPacket.deserialize(processedPacket2Bytes)
|
||||
let processedPacket2Res =
|
||||
SphinxPacket.deserialize(processedSP2.serializedSphinxPacket)
|
||||
if processedPacket2Res.isErr:
|
||||
error "Sphinx wrap error", err = processedPacket2Res.error
|
||||
fail()
|
||||
@@ -302,13 +308,13 @@ suite "Sphinx Tests":
|
||||
if res3.isErr:
|
||||
error "Error in Sphinx processing", err = res3.error
|
||||
fail()
|
||||
let (address3, delay3, processedPacket3, status3) = res3.get()
|
||||
let processedSP3 = res3.get()
|
||||
|
||||
if status3 != Exit:
|
||||
if processedSP3.status != Exit:
|
||||
error "Processing status should be Exit"
|
||||
fail()
|
||||
|
||||
if processedPacket3 != paddedMessage:
|
||||
if processedSP3.messageChunk != paddedMessage:
|
||||
error "Packet processing failed"
|
||||
fail()
|
||||
|
||||
@@ -340,18 +346,19 @@ suite "Sphinx Tests":
|
||||
if res1.isErr:
|
||||
error "Error in Sphinx processing", err = res1.error
|
||||
fail()
|
||||
let (address1, delay1, processedPacket1Bytes, status1) = res1.get()
|
||||
let processedSP1 = res1.get()
|
||||
|
||||
if status1 != Intermediate:
|
||||
if processedSP1.status != Intermediate:
|
||||
error "Processing status should be Intermediate"
|
||||
fail()
|
||||
|
||||
if processedPacket1Bytes.len != packetSize:
|
||||
if processedSP1.serializedSphinxPacket.len != packetSize:
|
||||
error "Packet length is not valid",
|
||||
pkt_len = $(processedPacket1Bytes.len), expected_len = $packetSize
|
||||
pkt_len = $(processedSP1.serializedSphinxPacket.len), expected_len = $packetSize
|
||||
fail()
|
||||
|
||||
let processedPacket1Res = SphinxPacket.deserialize(processedPacket1Bytes)
|
||||
let processedPacket1Res =
|
||||
SphinxPacket.deserialize(processedSP1.serializedSphinxPacket)
|
||||
if processedPacket1Res.isErr:
|
||||
error "Sphinx wrap error", err = processedPacket1Res.error
|
||||
fail()
|
||||
@@ -361,18 +368,19 @@ suite "Sphinx Tests":
|
||||
if res2.isErr:
|
||||
error "Error in Sphinx processing", err = res2.error
|
||||
fail()
|
||||
let (address2, delay2, processedPacket2Bytes, status2) = res2.get()
|
||||
let processedSP2 = res2.get()
|
||||
|
||||
if status2 != Intermediate:
|
||||
error "Processing status should be Success"
|
||||
if processedSP2.status != Intermediate:
|
||||
error "Processing status should be Intermediate"
|
||||
fail()
|
||||
|
||||
if processedPacket2Bytes.len != packetSize:
|
||||
if processedSP2.serializedSphinxPacket.len != packetSize:
|
||||
error "Packet length is not valid",
|
||||
pkt_len = $(processedPacket2Bytes.len), expected_len = $packetSize
|
||||
pkt_len = $(processedSP2.serializedSphinxPacket.len), expected_len = $packetSize
|
||||
fail()
|
||||
|
||||
let processedPacket2Res = SphinxPacket.deserialize(processedPacket2Bytes)
|
||||
let processedPacket2Res =
|
||||
SphinxPacket.deserialize(processedSP2.serializedSphinxPacket)
|
||||
if processedPacket2Res.isErr:
|
||||
error "Sphinx wrap error", err = processedPacket2Res.error
|
||||
fail()
|
||||
@@ -382,13 +390,13 @@ suite "Sphinx Tests":
|
||||
if res3.isErr:
|
||||
error "Error in Sphinx processing", err = res3.error
|
||||
fail()
|
||||
let (address3, delay3, processedPacket3, status3) = res3.get()
|
||||
let processedSP3 = res3.get()
|
||||
|
||||
if status3 != Reply:
|
||||
if processedSP3.status != Reply:
|
||||
error "Processing status should be Reply"
|
||||
fail()
|
||||
|
||||
let msgRes = processReply(surb.key, surb.secret.get(), processedPacket3)
|
||||
let msgRes = processReply(surb.key, surb.secret.get(), processedSP3.delta_prime)
|
||||
if msgRes.isErr:
|
||||
error "Reply processing failed", err = msgRes.error
|
||||
let msg = msgRes.get()
|
||||
@@ -437,9 +445,9 @@ suite "Sphinx Tests":
|
||||
if res.isErr:
|
||||
error "Error in Sphinx processing", err = res.error
|
||||
fail()
|
||||
let (_, _, _, status) = res.get()
|
||||
let processedSP1 = res.get()
|
||||
|
||||
if status != InvalidMAC:
|
||||
if processedSP1.status != InvalidMAC:
|
||||
error "Processing status should be InvalidMAC"
|
||||
fail()
|
||||
|
||||
@@ -472,19 +480,19 @@ suite "Sphinx Tests":
|
||||
if res1.isErr:
|
||||
error "Error in Sphinx processing", err = res1.error
|
||||
fail()
|
||||
let (_, _, _, status1) = res1.get()
|
||||
let processedSP1 = res1.get()
|
||||
|
||||
if status1 != Intermediate:
|
||||
error "Processing status should be Success"
|
||||
if processedSP1.status != Intermediate:
|
||||
error "Processing status should be Intermediate"
|
||||
fail()
|
||||
|
||||
let res2 = processSphinxPacket(packet, privateKeys[0], tm)
|
||||
if res2.isErr:
|
||||
error "Error in Sphinx processing", err = res2.error
|
||||
fail()
|
||||
let (_, _, _, status2) = res2.get()
|
||||
let processedSP2 = res2.get()
|
||||
|
||||
if status2 != Duplicate:
|
||||
if processedSP2.status != Duplicate:
|
||||
error "Processing status should be Duplicate"
|
||||
fail()
|
||||
|
||||
@@ -525,18 +533,20 @@ suite "Sphinx Tests":
|
||||
if res1.isErr:
|
||||
error "Error in Sphinx processing", err = res1.error
|
||||
fail()
|
||||
let (address1, delay1, processedPacket1Bytes, status1) = res1.get()
|
||||
let processedSP1 = res1.get()
|
||||
|
||||
if status1 != Intermediate:
|
||||
error "Processing status should be Success"
|
||||
if processedSP1.status != Intermediate:
|
||||
error "Processing status should be Intermediate"
|
||||
fail()
|
||||
|
||||
if processedPacket1Bytes.len != packetSize:
|
||||
if processedSP1.serializedSphinxPacket.len != packetSize:
|
||||
error "Packet length is not valid",
|
||||
pkt_len = $(processedPacket1Bytes.len), expected_len = $packetSize
|
||||
pkt_len = $(processedSP1.serializedSphinxPacket.len),
|
||||
expected_len = $packetSize
|
||||
fail()
|
||||
|
||||
let processedPacket1Res = SphinxPacket.deserialize(processedPacket1Bytes)
|
||||
let processedPacket1Res =
|
||||
SphinxPacket.deserialize(processedSP1.serializedSphinxPacket)
|
||||
if processedPacket1Res.isErr:
|
||||
error "Sphinx wrap error", err = processedPacket1Res.error
|
||||
fail()
|
||||
@@ -546,18 +556,20 @@ suite "Sphinx Tests":
|
||||
if res2.isErr:
|
||||
error "Error in Sphinx processing", err = res2.error
|
||||
fail()
|
||||
let (address2, delay2, processedPacket2Bytes, status2) = res2.get()
|
||||
let processedSP2 = res2.get()
|
||||
|
||||
if status2 != Intermediate:
|
||||
error "Processing status should be Success"
|
||||
if processedSP2.status != Intermediate:
|
||||
error "Processing status should be Intermediate"
|
||||
fail()
|
||||
|
||||
if processedPacket2Bytes.len != packetSize:
|
||||
if processedSP2.serializedSphinxPacket.len != packetSize:
|
||||
error "Packet length is not valid",
|
||||
pkt_len = $(processedPacket2Bytes.len), expected_len = $packetSize
|
||||
pkt_len = $(processedSP2.serializedSphinxPacket.len),
|
||||
expected_len = $packetSize
|
||||
fail()
|
||||
|
||||
let processedPacket2Res = SphinxPacket.deserialize(processedPacket2Bytes)
|
||||
let processedPacket2Res =
|
||||
SphinxPacket.deserialize(processedSP2.serializedSphinxPacket)
|
||||
if processedPacket1Res.isErr:
|
||||
error "Sphinx wrap error", err = processedPacket2Res.error
|
||||
fail()
|
||||
@@ -567,13 +579,13 @@ suite "Sphinx Tests":
|
||||
if res3.isErr:
|
||||
error "Error in Sphinx processing", err = res3.error
|
||||
fail()
|
||||
let (address3, delay3, processedPacket3, status3) = res3.get()
|
||||
let processedSP3 = res3.get()
|
||||
|
||||
if status3 != Reply:
|
||||
if processedSP3.status != Reply:
|
||||
error "Processing status should be Reply"
|
||||
fail()
|
||||
|
||||
let msgRes = processReply(surb.key, surb.secret.get(), processedPacket3)
|
||||
let msgRes = processReply(surb.key, surb.secret.get(), processedSP3.delta_prime)
|
||||
if msgRes.isErr:
|
||||
error "Reply processing failed", err = msgRes.error
|
||||
let msg = msgRes.get()
|
||||
|
||||
Reference in New Issue
Block a user