mirror of
https://github.com/vacp2p/mix.git
synced 2026-01-10 14:48:25 -05:00
chore: use new/init and naming of res variables (#60)
- `new` is used for those that return references - `init` for those that return values - Also changed the name of some `*res` variables that were result of `valueOr`, and use `==` for field elements directly - renamed serialize/deserialize procs
This commit is contained in:
@@ -50,7 +50,3 @@ proc multiplyBasePointWithScalars*(
|
||||
for i in 1 ..< scalars.len:
|
||||
Curve25519.mul(res, scalars[i]) # Multiply with each scalar
|
||||
ok(res)
|
||||
|
||||
# Compare two FieldElements
|
||||
proc compareFieldElements*(a, b: FieldElement): bool =
|
||||
a == b
|
||||
|
||||
@@ -36,7 +36,7 @@ method readLp*(
|
||||
method write*(
|
||||
self: MixEntryConnection, msg: seq[byte]
|
||||
): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true), public.} =
|
||||
self.mixDialer(@msg, self.codec, self.destMultiAddr, self.destPeerId)
|
||||
self.mixDialer(msg, self.codec, self.destMultiAddr, self.destPeerId)
|
||||
|
||||
proc write*(
|
||||
self: MixEntryConnection, msg: string
|
||||
|
||||
@@ -139,7 +139,7 @@ method closeImpl*(
|
||||
func hash*(self: MixExitConnection): Hash =
|
||||
discard
|
||||
|
||||
proc new*(T: typedesc[MixExitConnection], message: seq[byte]): MixExitConnection =
|
||||
proc new*(T: typedesc[MixExitConnection], message: seq[byte]): T =
|
||||
let instance = T(message: message)
|
||||
|
||||
when defined(libp2p_agents_metrics):
|
||||
|
||||
@@ -10,15 +10,15 @@ type MessageChunk* = object
|
||||
data: seq[byte]
|
||||
seqNo: uint32
|
||||
|
||||
proc initMessageChunk*(
|
||||
paddingLength: uint16, data: seq[byte], seqNo: uint32
|
||||
): MessageChunk =
|
||||
MessageChunk(paddingLength: paddingLength, data: data, seqNo: seqNo)
|
||||
proc init*(
|
||||
T: typedesc[MessageChunk], paddingLength: uint16, data: seq[byte], seqNo: uint32
|
||||
): T =
|
||||
T(paddingLength: paddingLength, data: data, seqNo: seqNo)
|
||||
|
||||
proc getMessageChunk*(msgChunk: MessageChunk): (uint16, seq[byte], uint32) =
|
||||
(msgChunk.paddingLength, msgChunk.data, msgChunk.seqNo)
|
||||
|
||||
proc serializeMessageChunk*(msgChunk: MessageChunk): Result[seq[byte], string] =
|
||||
proc serialize*(msgChunk: MessageChunk): Result[seq[byte], string] =
|
||||
let
|
||||
paddingBytes = uint16ToBytes(msgChunk.paddingLength)
|
||||
seqNoBytes = uint32ToBytes(msgChunk.seqNo)
|
||||
@@ -26,7 +26,7 @@ proc serializeMessageChunk*(msgChunk: MessageChunk): Result[seq[byte], string] =
|
||||
return err("Padded data must be exactly " & $dataSize & " bytes")
|
||||
return ok(paddingBytes & msgChunk.data & seqNoBytes)
|
||||
|
||||
proc deserializeMessageChunk*(data: openArray[byte]): Result[MessageChunk, string] =
|
||||
proc deserialize*(T: typedesc[MessageChunk], data: openArray[byte]): Result[T, string] =
|
||||
if len(data) != messageSize:
|
||||
return err("Data must be exactly " & $messageSize & " bytes")
|
||||
|
||||
@@ -37,14 +37,14 @@ proc deserializeMessageChunk*(data: openArray[byte]): Result[MessageChunk, strin
|
||||
|
||||
let seqNo = bytesToUInt32(data[paddingLengthSize + dataSize ..^ 1]).valueOr:
|
||||
return err("Error in bytes to sequence no. conversion: " & error)
|
||||
ok(MessageChunk(paddingLength: paddingLength, data: @chunk, seqNo: seqNo))
|
||||
ok(T(paddingLength: paddingLength, data: chunk, seqNo: seqNo))
|
||||
|
||||
proc ceilDiv*(a, b: int): int =
|
||||
(a + b - 1) div b
|
||||
|
||||
# Function for padding messages smaller than dataSize
|
||||
proc padMessage*(messageBytes: seq[byte], peerId: PeerId): MessageChunk =
|
||||
var seqNoGen = initSeqNo(peerId)
|
||||
var seqNoGen = SeqNo.init(peerId)
|
||||
seqNoGen.generateSeqNo(messageBytes)
|
||||
|
||||
let paddingLength = uint16(dataSize - len(messageBytes))
|
||||
@@ -68,7 +68,7 @@ proc unpadMessage*(msgChunk: MessageChunk): Result[seq[byte], string] =
|
||||
ok(msgChunk.data[msgChunk.paddingLength ..^ 1])
|
||||
|
||||
proc padAndChunkMessage*(messageBytes: seq[byte], peerId: PeerId): seq[MessageChunk] =
|
||||
var seqNoGen = initSeqNo(peerId)
|
||||
var seqNoGen = SeqNo.init(peerId)
|
||||
seqNoGen.generateSeqNo(messageBytes)
|
||||
|
||||
var chunks: seq[MessageChunk] = @[]
|
||||
@@ -90,7 +90,7 @@ proc padAndChunkMessage*(messageBytes: seq[byte], peerId: PeerId): seq[MessageCh
|
||||
else:
|
||||
chunkData
|
||||
|
||||
let msgChunk = initMessageChunk(paddingLength, paddedData, seqNoGen.getSeqNo())
|
||||
let msgChunk = MessageChunk.init(paddingLength, paddedData, seqNoGen.getSeqNo())
|
||||
chunks.add(msgChunk)
|
||||
|
||||
seqNoGen.incSeqNo()
|
||||
|
||||
@@ -7,7 +7,7 @@ type MixMessage* = object
|
||||
message*: seq[byte]
|
||||
codec*: string
|
||||
|
||||
proc new*(T: typedesc[MixMessage], message: openArray[byte], codec: string): T =
|
||||
proc init*(T: typedesc[MixMessage], message: openArray[byte], codec: string): T =
|
||||
return T(message: @message, codec: codec)
|
||||
|
||||
proc serialize*(mixMsg: MixMessage): Result[seq[byte], string] =
|
||||
|
||||
@@ -87,7 +87,7 @@ proc handleMixNodeConnection(
|
||||
of Exit:
|
||||
mix_messages_recvd.inc(labelValues = ["Exit"])
|
||||
# This is the exit node, forward to destination
|
||||
let msgChunk = deserializeMessageChunk(processedPkt).valueOr:
|
||||
let msgChunk = MessageChunk.deserialize(processedPkt).valueOr:
|
||||
error "Deserialization failed", err = error
|
||||
mix_messages_error.inc(labelValues = ["Exit", "INVALID_SPHINX"])
|
||||
return
|
||||
@@ -220,7 +220,7 @@ proc handleMixNodeConnection(
|
||||
discard
|
||||
|
||||
proc getMaxMessageSizeForCodec*(codec: string): Result[int, string] =
|
||||
let serializedMsg = ?MixMessage.new(@[], codec).serialize()
|
||||
let serializedMsg = ?MixMessage.init(@[], codec).serialize()
|
||||
if serializedMsg.len > dataSize:
|
||||
return err("cannot encode messages for this codec")
|
||||
return ok(dataSize - serializedMsg.len)
|
||||
@@ -233,7 +233,7 @@ proc anonymizeLocalProtocolSend*(
|
||||
destPeerId: PeerId,
|
||||
exitNodeIsDestination: bool,
|
||||
) {.async.} =
|
||||
let mixMsg = MixMessage.new(msg, codec)
|
||||
let mixMsg = MixMessage.init(msg, codec)
|
||||
|
||||
let serialized = mixMsg.serialize().valueOr:
|
||||
error "Serialization failed", err = error
|
||||
@@ -320,7 +320,7 @@ proc anonymizeLocalProtocolSend*(
|
||||
#TODO: should we skip and pick a different node here??
|
||||
return
|
||||
|
||||
hop.add(initHop(multiAddrBytes))
|
||||
hop.add(Hop.init(multiAddrBytes))
|
||||
|
||||
# Compute delay
|
||||
let delayMilliSec = cryptoRandomInt(3).valueOr:
|
||||
@@ -329,7 +329,7 @@ proc anonymizeLocalProtocolSend*(
|
||||
return
|
||||
delay.add(uint16ToBytes(uint16(delayMilliSec)))
|
||||
i = i + 1
|
||||
let serializedRes = serializeMessageChunk(paddedMsg).valueOr:
|
||||
let serializedMsgChunk = paddedMsg.serialize().valueOr:
|
||||
error "Failed to serialize padded message", err = error
|
||||
mix_messages_error.inc(labelValues = ["Entry", "NON_RECOVERABLE"])
|
||||
return
|
||||
@@ -342,11 +342,11 @@ proc anonymizeLocalProtocolSend*(
|
||||
error "Failed to convert multiaddress to bytes", err = error
|
||||
mix_messages_error.inc(labelValues = ["Entry", "INVALID_DEST"])
|
||||
return
|
||||
destHop = Opt.some(initHop(destAddrBytes))
|
||||
destHop = Opt.some(Hop.init(destAddrBytes))
|
||||
|
||||
# Wrap in Sphinx packet
|
||||
let sphinxPacket = wrapInSphinxPacket(
|
||||
initMessage(serializedRes), publicKeys, delay, hop, destHop
|
||||
Message.init(serializedMsgChunk), publicKeys, delay, hop, destHop
|
||||
).valueOr:
|
||||
error "Failed to wrap in sphinx packet", err = error
|
||||
mix_messages_error.inc(labelValues = ["Entry", "NON_RECOVERABLE"])
|
||||
@@ -385,20 +385,28 @@ proc anonymizeLocalProtocolSend*(
|
||||
except CatchableError as e:
|
||||
error "Failed to close outgoing stream: ", err = e.msg
|
||||
|
||||
proc createMixProtocol*(
|
||||
method init*(mixProtocol: MixProtocol) {.gcsafe, raises: [].} =
|
||||
proc handle(conn: Connection, proto: string) {.async: (raises: [CancelledError]).} =
|
||||
await mixProtocol.handleMixNodeConnection(conn, proto)
|
||||
|
||||
mixProtocol.codecs = @[MixProtocolID]
|
||||
mixProtocol.handler = handle
|
||||
|
||||
proc new*(
|
||||
T: typedesc[MixProtocol],
|
||||
mixNodeInfo: MixNodeInfo,
|
||||
pubNodeInfo: Table[PeerId, MixPubInfo],
|
||||
switch: Switch,
|
||||
tagManager: TagManager,
|
||||
handler: ProtocolHandler,
|
||||
): Result[MixProtocol, string] =
|
||||
let mixProto = new MixProtocol
|
||||
): Result[T, string] =
|
||||
let mixProto = new(T)
|
||||
mixProto.mixNodeInfo = mixNodeInfo
|
||||
mixProto.pubNodeInfo = pubNodeInfo
|
||||
mixProto.switch = switch
|
||||
mixProto.tagManager = tagManager
|
||||
mixProto.pHandler = handler
|
||||
mixProto.init()
|
||||
mixProto.init() # TODO: constructor should probably not call init
|
||||
|
||||
return ok(mixProto)
|
||||
|
||||
@@ -425,15 +433,13 @@ proc new*(
|
||||
error "Error during execution of MixProtocol handler: ", err = e.msg
|
||||
return
|
||||
|
||||
let mixProto = T(
|
||||
mixNodeInfo: mixNodeInfo,
|
||||
pubNodeInfo: pubNodeInfo,
|
||||
switch: switch,
|
||||
tagManager: initTagManager(),
|
||||
pHandler: sendHandlerFunc,
|
||||
)
|
||||
let mixProto =
|
||||
?MixProtocol.new(
|
||||
mixNodeInfo, pubNodeInfo, switch, TagManager.new(), sendHandlerFunc
|
||||
)
|
||||
|
||||
mixProto.init() # TODO: constructor should probably not call init
|
||||
|
||||
mixProto.init()
|
||||
return ok(mixProto)
|
||||
|
||||
# TODO: is this needed
|
||||
@@ -452,17 +458,10 @@ proc initialize*(
|
||||
mixProtocol.mixNodeInfo = localMixNodeInfo
|
||||
mixProtocol.switch = switch
|
||||
mixProtocol.pubNodeInfo = mixNodeTable
|
||||
mixProtocol.tagManager = initTagManager()
|
||||
mixProtocol.tagManager = TagManager.new()
|
||||
|
||||
mixProtocol.init()
|
||||
|
||||
method init*(mixProtocol: MixProtocol) {.gcsafe, raises: [].} =
|
||||
proc handle(conn: Connection, proto: string) {.async: (raises: [CancelledError]).} =
|
||||
await mixProtocol.handleMixNodeConnection(conn, proto)
|
||||
|
||||
mixProtocol.codecs = @[MixProtocolID]
|
||||
mixProtocol.handler = handle
|
||||
|
||||
# TODO: is this needed
|
||||
method setNodePool*(
|
||||
mixProtocol: MixProtocol, mixNodeTable: Table[PeerId, MixPubInfo]
|
||||
|
||||
@@ -5,8 +5,8 @@ import ./crypto
|
||||
type SeqNo* = object
|
||||
counter: uint32
|
||||
|
||||
proc initSeqNo*(peerId: PeerId): SeqNo =
|
||||
var seqNo: SeqNo
|
||||
proc init*(T: typedesc[SeqNo], peerId: PeerId): T =
|
||||
var seqNo = T()
|
||||
let peerIdHash = sha256_hash(peerId.data)
|
||||
for i in 0 .. 3:
|
||||
seqNo.counter = seqNo.counter or (uint32(peerIdHash[i]) shl (8 * (3 - i)))
|
||||
|
||||
@@ -6,13 +6,15 @@ type Header* = object
|
||||
Beta: seq[byte]
|
||||
Gamma: seq[byte]
|
||||
|
||||
proc initHeader*(alpha: seq[byte], beta: seq[byte], gamma: seq[byte]): Header =
|
||||
return Header(Alpha: alpha, Beta: beta, Gamma: gamma)
|
||||
proc init*(
|
||||
T: typedesc[Header], alpha: seq[byte], beta: seq[byte], gamma: seq[byte]
|
||||
): T =
|
||||
return T(Alpha: alpha, Beta: beta, Gamma: gamma)
|
||||
|
||||
proc getHeader*(header: Header): (seq[byte], seq[byte], seq[byte]) =
|
||||
(header.Alpha, header.Beta, header.Gamma)
|
||||
|
||||
proc serializeHeader*(header: Header): Result[seq[byte], string] =
|
||||
proc serialize*(header: Header): Result[seq[byte], string] =
|
||||
if len(header.Alpha) != alphaSize:
|
||||
return err("Alpha must be exactly " & $alphaSize & " bytes")
|
||||
if len(header.Beta) != betaSize:
|
||||
@@ -24,43 +26,44 @@ proc serializeHeader*(header: Header): Result[seq[byte], string] =
|
||||
type Message* = object
|
||||
Content: seq[byte]
|
||||
|
||||
proc initMessage*(content: seq[byte]): Message =
|
||||
return Message(Content: content)
|
||||
proc init*(T: typedesc[Message], content: seq[byte]): T =
|
||||
return T(Content: content)
|
||||
|
||||
proc getMessage*(message: Message): seq[byte] =
|
||||
proc getContent*(message: Message): seq[byte] =
|
||||
return message.Content
|
||||
|
||||
proc serializeMessage*(message: Message): Result[seq[byte], string] =
|
||||
proc serialize*(message: Message): Result[seq[byte], string] =
|
||||
if len(message.Content) != messageSize:
|
||||
return err("Message must be exactly " & $(messageSize) & " bytes")
|
||||
var res = newSeq[byte](k) # Prepend k bytes of zero padding
|
||||
res.add(message.Content)
|
||||
return ok(res)
|
||||
|
||||
proc deserializeMessage*(serializedMessage: openArray[byte]): Result[Message, string] =
|
||||
proc deserialize*(
|
||||
T: typedesc[Message], serializedMessage: openArray[byte]
|
||||
): Result[T, string] =
|
||||
if len(serializedMessage) != payloadSize:
|
||||
return err("Serialized message must be exactly " & $payloadSize & " bytes")
|
||||
let content = serializedMessage[k ..^ 1]
|
||||
return ok(Message(Content: content))
|
||||
return ok(T(Content: serializedMessage[k ..^ 1]))
|
||||
|
||||
type Hop* = object
|
||||
MultiAddress: seq[byte]
|
||||
|
||||
proc initHop*(multiAddress: seq[byte]): Hop =
|
||||
return Hop(MultiAddress: multiAddress)
|
||||
proc init*(T: typedesc[Hop], multiAddress: seq[byte]): T =
|
||||
T(MultiAddress: multiAddress)
|
||||
|
||||
proc getHop*(hop: Hop): seq[byte] =
|
||||
return hop.MultiAddress
|
||||
|
||||
proc serializeHop*(hop: Hop): Result[seq[byte], string] =
|
||||
proc serialize*(hop: Hop): Result[seq[byte], string] =
|
||||
if len(hop.MultiAddress) != addrSize:
|
||||
return err("MultiAddress must be exactly " & $addrSize & " bytes")
|
||||
return ok(hop.MultiAddress)
|
||||
|
||||
proc deserializeHop*(data: openArray[byte]): Result[Hop, string] =
|
||||
proc deserialize*(T: typedesc[Hop], data: openArray[byte]): Result[T, string] =
|
||||
if len(data) != addrSize:
|
||||
return err("MultiAddress must be exactly " & $addrSize & " bytes")
|
||||
return ok(Hop(MultiAddress: @data))
|
||||
return ok(T(MultiAddress: @data))
|
||||
|
||||
type RoutingInfo* = object
|
||||
Addr: Hop
|
||||
@@ -68,15 +71,19 @@ type RoutingInfo* = object
|
||||
Gamma: seq[byte]
|
||||
Beta: seq[byte]
|
||||
|
||||
proc initRoutingInfo*(
|
||||
address: Hop, delay: seq[byte], gamma: seq[byte], beta: seq[byte]
|
||||
): RoutingInfo =
|
||||
return RoutingInfo(Addr: address, Delay: delay, Gamma: gamma, Beta: beta)
|
||||
proc init*(
|
||||
T: typedesc[RoutingInfo],
|
||||
address: Hop,
|
||||
delay: seq[byte],
|
||||
gamma: seq[byte],
|
||||
beta: seq[byte],
|
||||
): T =
|
||||
return T(Addr: address, Delay: delay, Gamma: gamma, Beta: beta)
|
||||
|
||||
proc getRoutingInfo*(info: RoutingInfo): (Hop, seq[byte], seq[byte], seq[byte]) =
|
||||
(info.Addr, info.Delay, info.Gamma, info.Beta)
|
||||
|
||||
proc serializeRoutingInfo*(info: RoutingInfo): Result[seq[byte], string] =
|
||||
proc serialize*(info: RoutingInfo): Result[seq[byte], string] =
|
||||
if len(info.Delay) != delaySize:
|
||||
return err("Delay must be exactly " & $delaySize & " bytes")
|
||||
if len(info.Gamma) != gammaSize:
|
||||
@@ -84,21 +91,21 @@ proc serializeRoutingInfo*(info: RoutingInfo): Result[seq[byte], string] =
|
||||
if len(info.Beta) != (((r * (t + 1)) - t) * k):
|
||||
return err("Beta must be exactly " & $(((r * (t + 1)) - t) * k) & " bytes")
|
||||
|
||||
let addrBytes = serializeHop(info.Addr).valueOr:
|
||||
let addrBytes = info.Addr.serialize().valueOr:
|
||||
return err("Serialize hop error: " & error)
|
||||
|
||||
return ok(addrBytes & info.Delay & info.Gamma & info.Beta)
|
||||
|
||||
proc deserializeRoutingInfo*(data: openArray[byte]): Result[RoutingInfo, string] =
|
||||
proc deserialize*(T: typedesc[RoutingInfo], data: openArray[byte]): Result[T, string] =
|
||||
if len(data) != betaSize + ((t + 1) * k):
|
||||
return err("Data must be exactly " & $(betaSize + ((t + 1) * k)) & " bytes")
|
||||
|
||||
let hopRes = deserializeHop(data[0 .. addrSize - 1]).valueOr:
|
||||
let hop = Hop.deserialize(data[0 .. addrSize - 1]).valueOr:
|
||||
return err("Deserialize hop error: " & error)
|
||||
|
||||
return ok(
|
||||
RoutingInfo(
|
||||
Addr: hopRes,
|
||||
Addr: hop,
|
||||
Delay: data[addrSize .. (addrSize + delaySize - 1)],
|
||||
Gamma: data[(addrSize + delaySize) .. (addrSize + delaySize + gammaSize - 1)],
|
||||
Beta:
|
||||
@@ -107,22 +114,22 @@ proc deserializeRoutingInfo*(data: openArray[byte]): Result[RoutingInfo, string]
|
||||
)
|
||||
|
||||
type SphinxPacket* = object
|
||||
Hdr: Header
|
||||
Payload: seq[byte]
|
||||
Hdr*: Header
|
||||
Payload*: seq[byte]
|
||||
|
||||
proc initSphinxPacket*(header: Header, payload: seq[byte]): SphinxPacket =
|
||||
return SphinxPacket(Hdr: header, Payload: payload)
|
||||
proc init*(T: typedesc[SphinxPacket], header: Header, payload: seq[byte]): T =
|
||||
T(Hdr: header, Payload: payload)
|
||||
|
||||
proc getSphinxPacket*(packet: SphinxPacket): (Header, seq[byte]) =
|
||||
(packet.Hdr, packet.Payload)
|
||||
|
||||
proc serializeSphinxPacket*(packet: SphinxPacket): Result[seq[byte], string] =
|
||||
let headerBytes = serializeHeader(packet.Hdr).valueOr:
|
||||
proc serialize*(packet: SphinxPacket): Result[seq[byte], string] =
|
||||
let headerBytes = packet.Hdr.serialize().valueOr:
|
||||
return err("Serialize sphinx packet header error: " & error)
|
||||
|
||||
return ok(headerBytes & packet.Payload)
|
||||
|
||||
proc deserializeSphinxPacket*(data: openArray[byte]): Result[SphinxPacket, string] =
|
||||
proc deserialize*(T: typedesc[SphinxPacket], data: openArray[byte]): Result[T, string] =
|
||||
if len(data) != packetSize:
|
||||
return err("Sphinx packet size must be exactly " & $packetSize & " bytes")
|
||||
|
||||
|
||||
@@ -118,7 +118,7 @@ proc computeBetaGammaDelta(
|
||||
msg: Message,
|
||||
delay: openArray[seq[byte]],
|
||||
destHop: Opt[Hop],
|
||||
): Result[(seq[byte], seq[byte], seq[byte]), string] =
|
||||
): Result[(seq[byte], seq[byte], seq[byte]), string] = # TODO: name tuples
|
||||
let sLen = s.len
|
||||
var
|
||||
beta: seq[byte]
|
||||
@@ -143,8 +143,8 @@ proc computeBetaGammaDelta(
|
||||
if i == sLen - 1:
|
||||
var padding: seq[byte]
|
||||
if destHop.isSome:
|
||||
let destBytes = serializeHop(destHop.get()).valueOr:
|
||||
return err("Error in destination address serialization: " & error)
|
||||
let destBytes = destHop.get().serialize().valueOr:
|
||||
return err("Error in destination address serialization: " & error)
|
||||
let paddingLength = (((t + 1) * (r - L)) + 2) * k
|
||||
padding = destBytes & delay[i] & newSeq[byte](paddingLength)
|
||||
else:
|
||||
@@ -155,30 +155,24 @@ proc computeBetaGammaDelta(
|
||||
return err("Error in aes: " & error)
|
||||
beta = aesRes & filler
|
||||
|
||||
let serializeRes = serializeMessage(msg).valueOr:
|
||||
let serializedMsg = msg.serialize().valueOr:
|
||||
return err("Message serialization error: " & error)
|
||||
|
||||
let deltaRes = aes_ctr(delta_aes_key, delta_iv, serializeRes)
|
||||
if deltaRes.isErr:
|
||||
return err("Error in aes: " & deltaRes.error)
|
||||
delta = deltaRes.get()
|
||||
delta = aes_ctr(delta_aes_key, delta_iv, serializedMsg).valueOr:
|
||||
return err("Error in aes: " & error)
|
||||
else:
|
||||
let routingInfo = initRoutingInfo(
|
||||
let routingInfo = RoutingInfo.init(
|
||||
hop[i + 1], delay[i + 1], gamma, beta[0 .. (((r * (t + 1)) - t) * k) - 1]
|
||||
)
|
||||
|
||||
let serializeRes = serializeRoutingInfo(routingInfo).valueOr:
|
||||
let serializedRoutingInfo = routingInfo.serialize().valueOr:
|
||||
return err("Routing info serialization error: " & error)
|
||||
|
||||
let betaRes = aes_ctr(beta_aes_key, beta_iv, serializeRes)
|
||||
if betaRes.isErr:
|
||||
return err("Error in aes: " & betaRes.error)
|
||||
beta = betaRes.get()
|
||||
beta = aes_ctr(beta_aes_key, beta_iv, serializedRoutingInfo).valueOr:
|
||||
return err("Error in aes: " & error)
|
||||
|
||||
let deltaRes = aes_ctr(delta_aes_key, delta_iv, delta)
|
||||
if deltaRes.isErr:
|
||||
return err("Error in aes: " & deltaRes.error)
|
||||
delta = deltaRes.get()
|
||||
delta = aes_ctr(delta_aes_key, delta_iv, delta).valueOr:
|
||||
return err("Error in aes: " & error)
|
||||
|
||||
gamma = toSeq(hmac(mac_key, beta))
|
||||
|
||||
@@ -192,45 +186,41 @@ proc wrapInSphinxPacket*(
|
||||
destHop: Opt[Hop],
|
||||
): Result[seq[byte], string] =
|
||||
# Compute alphas and shared secrets
|
||||
let res1 = computeAlpha(publicKeys)
|
||||
if res1.isErr:
|
||||
return err("Error in alpha generation: " & res1.error)
|
||||
let (alpha_0, s) = res1.get()
|
||||
let (alpha_0, s) = computeAlpha(publicKeys).valueOr:
|
||||
return err("Error in alpha generation: " & error)
|
||||
|
||||
# Compute betas, gammas, and deltas
|
||||
let res2 = computeBetaGammaDelta(s, hop, msg, delay, destHop)
|
||||
if res2.isErr:
|
||||
return err("Error in beta, gamma, and delta generation: " & res2.error)
|
||||
let (beta_0, gamma_0, delta_0) = res2.get()
|
||||
let (beta_0, gamma_0, delta_0) = computeBetaGammaDelta(s, hop, msg, delay, destHop).valueOr:
|
||||
return err("Error in beta, gamma, and delta generation: " & error)
|
||||
|
||||
# Serialize sphinx packet
|
||||
let sphinxPacket = initSphinxPacket(initHeader(alpha_0, beta_0, gamma_0), delta_0)
|
||||
let sphinxPacket = SphinxPacket.init(Header.init(alpha_0, beta_0, gamma_0), delta_0)
|
||||
|
||||
let serializeRes = serializeSphinxPacket(sphinxPacket).valueOr:
|
||||
let serialized = sphinxPacket.serialize().valueOr:
|
||||
return err("Sphinx packet serialization error: " & error)
|
||||
|
||||
return ok(serializeRes)
|
||||
return ok(serialized)
|
||||
|
||||
proc processSphinxPacket*(
|
||||
serSphinxPacket: seq[byte],
|
||||
privateKey: FieldElement,
|
||||
tm: var TagManager,
|
||||
isDestEmbedded: bool,
|
||||
): Result[(Hop, seq[byte], seq[byte], ProcessingStatus), string] =
|
||||
): Result[(Hop, seq[byte], seq[byte], ProcessingStatus), string] = # TODO: named touple
|
||||
# Deserialize the Sphinx packet
|
||||
let deserializeRes = deserializeSphinxPacket(serSphinxPacket).valueOr:
|
||||
let sphinxPacket = SphinxPacket.deserialize(serSphinxPacket).valueOr:
|
||||
return err("Sphinx packet deserialization error: " & error)
|
||||
|
||||
let
|
||||
(header, payload) = getSphinxPacket(deserializeRes)
|
||||
(header, payload) = sphinxPacket.getSphinxPacket()
|
||||
(alpha, beta, gamma) = getHeader(header)
|
||||
|
||||
# Compute shared secret
|
||||
let alphaRes = bytesToFieldElement(alpha).valueOr:
|
||||
let alphaFE = bytesToFieldElement(alpha).valueOr:
|
||||
return err("Error in bytes to field element conversion: " & error)
|
||||
|
||||
let
|
||||
s = multiplyPointWithScalars(alphaRes, [privateKey])
|
||||
s = multiplyPointWithScalars(alphaFE, [privateKey])
|
||||
sBytes = fieldElementToBytes(s)
|
||||
|
||||
# Check if the tag has been seen
|
||||
@@ -285,38 +275,41 @@ proc processSphinxPacket*(
|
||||
0
|
||||
|
||||
if B[bOffset .. bOffset + paddingLength - 1] == zeroPadding:
|
||||
let deserializeRes = deserializeMessage(delta_prime).valueOr:
|
||||
let msg = Message.deserialize(delta_prime).valueOr:
|
||||
return err("Message deserialization error: " & error)
|
||||
let msg = getMessage(deserializeRes)
|
||||
|
||||
let content = msg.getContent()
|
||||
|
||||
if isDestEmbedded:
|
||||
let hop = deserializeHop(B[0 .. addrSize - 1]).valueOr:
|
||||
let hop = Hop.deserialize(B[0 .. addrSize - 1]).valueOr:
|
||||
return err(error)
|
||||
return ok((hop, B[addrSize .. ((t * k) - 1)], msg[0 .. messageSize - 1], Exit))
|
||||
return
|
||||
ok((hop, B[addrSize .. ((t * k) - 1)], content[0 .. messageSize - 1], Exit))
|
||||
else:
|
||||
return ok((Hop(), @[], msg[0 .. messageSize - 1], Exit))
|
||||
return ok((Hop(), @[], content[0 .. messageSize - 1], Exit))
|
||||
else:
|
||||
# Extract routing information from B
|
||||
let deserializeRes = deserializeRoutingInfo(B).valueOr:
|
||||
let routingInfo = RoutingInfo.deserialize(B).valueOr:
|
||||
return err("Routing info deserialization error: " & error)
|
||||
|
||||
let (address, delay, gamma_prime, beta_prime) = getRoutingInfo(deserializeRes)
|
||||
let (address, delay, gamma_prime, beta_prime) = routingInfo.getRoutingInfo()
|
||||
|
||||
# Compute alpha
|
||||
let blinder = bytesToFieldElement(sha256_hash(alpha & sBytes)).valueOr:
|
||||
return err("Error in bytes to field element conversion: " & error)
|
||||
|
||||
let alphaRes = bytesToFieldElement(alpha).valueOr:
|
||||
let alphaFE = bytesToFieldElement(alpha).valueOr:
|
||||
return err("Error in bytes to field element conversion: " & error)
|
||||
|
||||
let alpha_prime = multiplyPointWithScalars(alphaRes, [blinder])
|
||||
let alpha_prime = multiplyPointWithScalars(alphaFE, [blinder])
|
||||
|
||||
# Serialize sphinx packet
|
||||
let sphinxPkt = initSphinxPacket(
|
||||
initHeader(fieldElementToBytes(alpha_prime), beta_prime, gamma_prime), delta_prime
|
||||
let sphinxPkt = SphinxPacket.init(
|
||||
Header.init(fieldElementToBytes(alpha_prime), beta_prime, gamma_prime),
|
||||
delta_prime,
|
||||
)
|
||||
|
||||
let serializeRes = serializeSphinxPacket(sphinxPkt).valueOr:
|
||||
let serializedSP = sphinxPkt.serialize().valueOr:
|
||||
return err("Sphinx packet serialization error: " & error)
|
||||
|
||||
return ok((address, delay, serializeRes, Intermediate))
|
||||
return ok((address, delay, serializedSP, Intermediate))
|
||||
|
||||
@@ -5,8 +5,8 @@ type TagManager* = ref object
|
||||
lock: Lock
|
||||
seenTags: Table[FieldElement, bool]
|
||||
|
||||
proc initTagManager*(): TagManager =
|
||||
let tm = new(TagManager)
|
||||
proc new*(T: typedesc[TagManager]): T =
|
||||
let tm = T()
|
||||
tm.seenTags = initTable[FieldElement, bool]()
|
||||
initLock(tm.lock)
|
||||
return tm
|
||||
|
||||
@@ -41,7 +41,7 @@ suite "curve25519_tests":
|
||||
fail()
|
||||
let derivedPublicKey = derivedPublicKeyResult.get()
|
||||
|
||||
if not compareFieldElements(publicKey, derivedPublicKey):
|
||||
if publicKey != derivedPublicKey:
|
||||
error "Public keydoes not match derived key",
|
||||
publickey = publicKey, derivedkey = derivedPublicKey
|
||||
fail()
|
||||
@@ -69,6 +69,6 @@ suite "curve25519_tests":
|
||||
intermediate = public(x2)
|
||||
res2 = multiplyPointWithScalars(intermediate, @[x1])
|
||||
|
||||
if not compareFieldElements(res1, res2):
|
||||
if res1 != res2:
|
||||
error "Field element operations must be commutative", res1 = res1, res2 = res2
|
||||
fail()
|
||||
|
||||
@@ -14,13 +14,13 @@ suite "Fragmentation":
|
||||
chunks = padAndChunkMessage(message, peerId)
|
||||
(paddingLength, data, seqNo) = getMessageChunk(chunks[0])
|
||||
|
||||
let serializedRes = serializeMessageChunk(chunks[0])
|
||||
let serializedRes = chunks[0].serialize()
|
||||
if serializedRes.isErr:
|
||||
error "Serialization error", err = serializedRes.error
|
||||
fail()
|
||||
let serialized = serializedRes.get()
|
||||
|
||||
let deserializedRes = deserializeMessageChunk(serialized)
|
||||
let deserializedRes = MessageChunk.deserialize(serialized)
|
||||
if deserializedRes.isErr:
|
||||
error "Deserialization error", err = deserializedRes.error
|
||||
fail()
|
||||
|
||||
@@ -10,7 +10,7 @@ suite "mix_message_tests":
|
||||
let
|
||||
message = "Hello World!"
|
||||
codec = "/test/codec/1.0.0"
|
||||
mixMsg = MixMessage.new(message.toBytes(), codec)
|
||||
mixMsg = MixMessage.init(message.toBytes(), codec)
|
||||
|
||||
let serializedResult = mixMsg.serialize()
|
||||
if serializedResult.isErr:
|
||||
@@ -40,7 +40,7 @@ suite "mix_message_tests":
|
||||
let
|
||||
emptyMessage = ""
|
||||
codec = "/test/codec/1.0.0"
|
||||
mixMsg = MixMessage.new(emptyMessage.toBytes(), codec)
|
||||
mixMsg = MixMessage.init(emptyMessage.toBytes(), codec)
|
||||
|
||||
let serializedResult = mixMsg.serialize()
|
||||
if serializedResult.isErr:
|
||||
@@ -69,7 +69,7 @@ suite "mix_message_tests":
|
||||
codec = "/test/codec/1.0.0"
|
||||
destination =
|
||||
"/ip4/0.0.0.0/tcp/4242/p2p/16Uiu2HAmFkwLVsVh6gGPmSm9R3X4scJ5thVdKfWYeJsKeVrbcgVC"
|
||||
mixMsg = MixMessage.new(message.toBytes(), codec)
|
||||
mixMsg = MixMessage.init(message.toBytes(), codec)
|
||||
|
||||
let serializedResult = mixMsg.serializeWithDestination(destination)
|
||||
if serializedResult.isErr:
|
||||
|
||||
@@ -102,12 +102,12 @@ suite "Mix Node Tests":
|
||||
multiaddr = fMultiAddr, original = multiAddr
|
||||
fail()
|
||||
|
||||
if not compareFieldElements(fMixPubKey, mixPubKey):
|
||||
if fMixPubKey != mixPubKey:
|
||||
error "Mix public key does not match original mix public key",
|
||||
pubkey = fMixPubKey, original = mixPubKey
|
||||
fail()
|
||||
|
||||
if not compareFieldElements(fMixPrivKey, mixPrivKey):
|
||||
if fMixPrivKey != mixPrivKey:
|
||||
error "Mix private key does not match original mix private key",
|
||||
privkey = fMixPrivKey, original = mixPrivKey
|
||||
fail()
|
||||
@@ -167,12 +167,12 @@ suite "Mix Node Tests":
|
||||
multiaddr = rMultiAddr, original = multiAddr
|
||||
fail()
|
||||
|
||||
if not compareFieldElements(rMixPubKey, mixPubKey):
|
||||
if rMixPubKey != mixPubKey:
|
||||
error "Mix public key does not match original mix public key",
|
||||
pubkey = rMixPubKey, original = mixPubKey
|
||||
fail()
|
||||
|
||||
if not compareFieldElements(rMixPrivKey, mixPrivKey):
|
||||
if rMixPrivKey != mixPrivKey:
|
||||
error "Mix private key does not match original mix private key",
|
||||
privkey = rMixPrivKey, original = mixPrivKey
|
||||
fail()
|
||||
@@ -215,7 +215,7 @@ suite "Mix Node Tests":
|
||||
multiaddr = rMultiAddr, original = multiAddr
|
||||
fail()
|
||||
|
||||
if not compareFieldElements(rMixPubKey, mixPubKey):
|
||||
if rMixPubKey != mixPubKey:
|
||||
error "Mix public key does not match original mix public key",
|
||||
pubkey = rMixPubKey, original = mixPubKey
|
||||
fail()
|
||||
|
||||
@@ -10,7 +10,7 @@ suite "Sequence Number Generator":
|
||||
let
|
||||
peerId =
|
||||
PeerId.init("16Uiu2HAmFkwLVsVh6gGPmSm9R3X4scJ5thVdKfWYeJsKeVrbcgVC").get()
|
||||
seqNo = initSeqNo(peerId)
|
||||
seqNo = SeqNo.init(peerId)
|
||||
if seqNo.counter == 0:
|
||||
error "Sequence number initialization failed", counter = seqNo.counter
|
||||
fail()
|
||||
@@ -21,7 +21,7 @@ suite "Sequence Number Generator":
|
||||
PeerId.init("16Uiu2HAmFkwLVsVh6gGPmSm9R3X4scJ5thVdKfWYeJsKeVrbcgVC").get()
|
||||
msg1 = @[byte 1, 2, 3]
|
||||
msg2 = @[byte 4, 5, 6]
|
||||
var seqNo = initSeqNo(peerId)
|
||||
var seqNo = SeqNo.init(peerId)
|
||||
|
||||
generateSeqNo(seqNo, msg1)
|
||||
let seqNo1 = seqNo.counter
|
||||
@@ -39,7 +39,7 @@ suite "Sequence Number Generator":
|
||||
peerId =
|
||||
PeerId.init("16Uiu2HAmFkwLVsVh6gGPmSm9R3X4scJ5thVdKfWYeJsKeVrbcgVC").get()
|
||||
msg = @[byte 1, 2, 3]
|
||||
var seqNo = initSeqNo(peerId)
|
||||
var seqNo = SeqNo.init(peerId)
|
||||
|
||||
generateSeqNo(seqNo, msg)
|
||||
let seqNo1 = seqNo.counter
|
||||
@@ -61,8 +61,8 @@ suite "Sequence Number Generator":
|
||||
PeerId.init("16Uiu2HAm6WNzw8AssyPscYYi8x1bY5wXyQrGTShRH75bh5dPCjBQ").get()
|
||||
|
||||
var
|
||||
seqNo1 = initSeqNo(peerId1)
|
||||
seqNo2 = initSeqNo(peerId2)
|
||||
seqNo1 = SeqNo.init(peerId1)
|
||||
seqNo2 = SeqNo.init(peerId2)
|
||||
|
||||
if seqNo1.counter == seqNo2.counter:
|
||||
error "Sequence numbers for different peer IDs should be different",
|
||||
@@ -72,7 +72,7 @@ suite "Sequence Number Generator":
|
||||
test "increment_seq_no":
|
||||
let peerId =
|
||||
PeerId.init("16Uiu2HAmFkwLVsVh6gGPmSm9R3X4scJ5thVdKfWYeJsKeVrbcgVC").get()
|
||||
var seqNo: SeqNo = initSeqNo(peerId)
|
||||
var seqNo: SeqNo = SeqNo.init(peerId)
|
||||
let initialCounter = seqNo.counter
|
||||
|
||||
incSeqNo(seqNo)
|
||||
@@ -85,7 +85,7 @@ suite "Sequence Number Generator":
|
||||
test "seq_no_wraps_around_at_max_value":
|
||||
let peerId =
|
||||
PeerId.init("16Uiu2HAmFkwLVsVh6gGPmSm9R3X4scJ5thVdKfWYeJsKeVrbcgVC").get()
|
||||
var seqNo = initSeqNo(peerId)
|
||||
var seqNo = SeqNo.init(peerId)
|
||||
seqNo.counter = high(uint32) - 1
|
||||
if seqNo.counter != high(uint32) - 1:
|
||||
error "Sequence number must be max value",
|
||||
@@ -101,7 +101,7 @@ suite "Sequence Number Generator":
|
||||
let peerId =
|
||||
PeerId.init("16Uiu2HAmFkwLVsVh6gGPmSm9R3X4scJ5thVdKfWYeJsKeVrbcgVC").get()
|
||||
var
|
||||
seqNo = initSeqNo(peerId)
|
||||
seqNo = SeqNo.init(peerId)
|
||||
seenValues = initHashSet[uint32]()
|
||||
|
||||
for i in 0 ..< 10000:
|
||||
|
||||
@@ -6,11 +6,11 @@ import ../mix/[config, serialization]
|
||||
# Define test cases
|
||||
suite "serialization_tests":
|
||||
test "serialize_and_deserialize_header":
|
||||
let header = initHeader(
|
||||
let header = Header.init(
|
||||
newSeq[byte](alphaSize), newSeq[byte](betaSize), newSeq[byte](gammaSize)
|
||||
)
|
||||
|
||||
let serializedRes = serializeHeader(header)
|
||||
let serializedRes = header.serialize()
|
||||
if serializedRes.isErr:
|
||||
error "Failed to serialize header", err = serializedRes.error
|
||||
fail()
|
||||
@@ -22,34 +22,34 @@ suite "serialization_tests":
|
||||
fail()
|
||||
|
||||
test "serialize_and_deserialize_message":
|
||||
let message = initMessage(newSeq[byte](messageSize))
|
||||
let message = Message.init(newSeq[byte](messageSize))
|
||||
|
||||
let serializedRes = serializeMessage(message)
|
||||
let serializedRes = message.serialize()
|
||||
if serializedRes.isErr:
|
||||
error "Failed to serialize message", err = serializedRes.error
|
||||
fail()
|
||||
let serialized = serializedRes.get()
|
||||
|
||||
let deserializedRes = deserializeMessage(serialized)
|
||||
let deserializedRes = Message.deserialize(serialized)
|
||||
if deserializedRes.isErr:
|
||||
error "Failed to deserialize message", err = deserializedRes.error
|
||||
fail()
|
||||
let deserialized = deserializedRes.get()
|
||||
|
||||
if getMessage(message) != getMessage(deserialized):
|
||||
if getContent(message) != getContent(deserialized):
|
||||
error "Deserialized message does not match the original message"
|
||||
fail()
|
||||
|
||||
test "serialize_and_deserialize_hop":
|
||||
let hop = initHop(newSeq[byte](addrSize))
|
||||
let hop = Hop.init(newSeq[byte](addrSize))
|
||||
|
||||
let serializedRes = serializeHop(hop)
|
||||
let serializedRes = hop.serialize()
|
||||
if serializedRes.isErr:
|
||||
error "Failed to serialize hop", err = serializedRes.error
|
||||
fail()
|
||||
let serialized = serializedRes.get()
|
||||
|
||||
let deserializedRes = deserializeHop(serialized)
|
||||
let deserializedRes = Hop.deserialize(serialized)
|
||||
if deserializedRes.isErr:
|
||||
error "Failed to deserialize hop", err = deserializedRes.error
|
||||
fail()
|
||||
@@ -60,14 +60,14 @@ suite "serialization_tests":
|
||||
fail()
|
||||
|
||||
test "serialize_and_deserialize_routing_info":
|
||||
let routingInfo = initRoutingInfo(
|
||||
initHop(newSeq[byte](addrSize)),
|
||||
let routingInfo = RoutingInfo.init(
|
||||
Hop.init(newSeq[byte](addrSize)),
|
||||
newSeq[byte](delaySize),
|
||||
newSeq[byte](gammaSize),
|
||||
newSeq[byte](((r * (t + 1)) - t) * k),
|
||||
)
|
||||
|
||||
let serializedRes = serializeRoutingInfo(routingInfo)
|
||||
let serializedRes = routingInfo.serialize()
|
||||
if serializedRes.isErr:
|
||||
error "Failed to serialize routing info", err = serializedRes.error
|
||||
fail()
|
||||
@@ -77,7 +77,7 @@ suite "serialization_tests":
|
||||
suffixLength = (t + 1) * k
|
||||
suffix = newSeq[byte](suffixLength)
|
||||
|
||||
let deserializedRes = deserializeRoutingInfo(serialized & suffix)
|
||||
let deserializedRes = RoutingInfo.deserialize(serialized & suffix)
|
||||
if deserializedRes.isErr:
|
||||
error "Failed to deserialize routing info", err = deserializedRes.error
|
||||
fail()
|
||||
@@ -105,19 +105,19 @@ suite "serialization_tests":
|
||||
|
||||
test "serialize_and_deserialize_sphinx_packet":
|
||||
let
|
||||
header = initHeader(
|
||||
header = Header.init(
|
||||
newSeq[byte](alphaSize), newSeq[byte](betaSize), newSeq[byte](gammaSize)
|
||||
)
|
||||
payload = newSeq[byte](payloadSize)
|
||||
packet = initSphinxPacket(header, payload)
|
||||
packet = SphinxPacket.init(header, payload)
|
||||
|
||||
let serializedRes = serializeSphinxPacket(packet)
|
||||
let serializedRes = packet.serialize()
|
||||
if serializedRes.isErr:
|
||||
error "Failed to serialize sphinx packet", err = serializedRes.error
|
||||
fail()
|
||||
let serialized = serializedRes.get()
|
||||
|
||||
let deserializedRes = deserializeSphinxPacket(serialized)
|
||||
let deserializedRes = SphinxPacket.deserialize(serialized)
|
||||
if deserializedRes.isErr:
|
||||
error "Failed to deserialize sphinx packet", err = deserializedRes.error
|
||||
fail()
|
||||
|
||||
@@ -42,12 +42,12 @@ proc createDummyData(): (
|
||||
|
||||
hops =
|
||||
@[
|
||||
initHop(newSeq[byte](addrSize)),
|
||||
initHop(newSeq[byte](addrSize)),
|
||||
initHop(newSeq[byte](addrSize)),
|
||||
Hop.init(newSeq[byte](addrSize)),
|
||||
Hop.init(newSeq[byte](addrSize)),
|
||||
Hop.init(newSeq[byte](addrSize)),
|
||||
]
|
||||
|
||||
message = initMessage(newSeq[byte](messageSize))
|
||||
message = Message.init(newSeq[byte](messageSize))
|
||||
|
||||
return (message, privateKeys, publicKeys, delay, hops)
|
||||
|
||||
@@ -56,7 +56,7 @@ suite "Sphinx Tests":
|
||||
var tm: TagManager
|
||||
|
||||
setup:
|
||||
tm = initTagManager()
|
||||
tm = TagManager.new()
|
||||
|
||||
teardown:
|
||||
clearTags(tm)
|
||||
@@ -114,7 +114,7 @@ suite "Sphinx Tests":
|
||||
error "Processing status should be Exit"
|
||||
fail()
|
||||
|
||||
let processedMessage = initMessage(processedPacket3)
|
||||
let processedMessage = Message.init(processedPacket3)
|
||||
if processedMessage != message:
|
||||
error "Packet processing failed"
|
||||
fail()
|
||||
@@ -199,7 +199,7 @@ suite "Sphinx Tests":
|
||||
let paddedMessage = padMessage(message, messageSize)
|
||||
|
||||
let packetRes = wrapInSphinxPacket(
|
||||
initMessage(paddedMessage), publicKeys, delay, hops, Opt.none(Hop)
|
||||
Message.init(paddedMessage), publicKeys, delay, hops, Opt.none(Hop)
|
||||
)
|
||||
if packetRes.isErr:
|
||||
error "Sphinx wrap error", err = packetRes.error
|
||||
|
||||
@@ -7,7 +7,7 @@ suite "tag_manager_tests":
|
||||
var tm: TagManager
|
||||
|
||||
setup:
|
||||
tm = initTagManager()
|
||||
tm = TagManager.new()
|
||||
|
||||
teardown:
|
||||
clearTags(tm)
|
||||
|
||||
Reference in New Issue
Block a user