refactor: make mix protocol agnostic (#59)

This commit is contained in:
richΛrd
2025-07-24 17:47:12 -04:00
committed by GitHub
parent cdd5ab9657
commit ed242006b4
8 changed files with 144 additions and 167 deletions

View File

@@ -11,7 +11,7 @@ jobs:
uses: status-im/nimbus-common-workflow/.github/workflows/common.yml@main
with:
test-command: |
nimble c ./mix/examples/poc_gossipsub.nim
nimble c ./mix/examples/poc_gossipsub_repeated_runs.nim
nimble c ./mix/examples/poc_noresp_ping.nim
nimble c ./examples/poc_gossipsub.nim
nimble c ./examples/poc_gossipsub_repeated_runs.nim
nimble c ./examples/poc_noresp_ping.nim
nim-versions: '["version-2-0", "version-2-2", "devel"]'

View File

@@ -15,5 +15,6 @@ export writeMixNodeInfoToFile
export mixNodes
export getMixNodeInfo
export `new`
export getMaxMessageSizeForCodec
export deleteNodeInfoFolder
export deletePubInfoFolder

View File

@@ -1,18 +1,15 @@
import hashes, chronos, stew/byteutils, results
import libp2p/stream/connection
import ./protocol, ./mix_protocol
import ./mix_protocol
type MixDialer* = proc(
msg: seq[byte],
proto: ProtocolType,
destMultiAddr: Opt[MultiAddress],
destPeerId: PeerId,
msg: seq[byte], codec: string, destMultiAddr: Opt[MultiAddress], destPeerId: PeerId
): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true).}
type MixEntryConnection* = ref object of Connection
destMultiAddr: Opt[MultiAddress]
destPeerId: PeerId
proto: ProtocolType
codec: string
mixDialer: MixDialer
method readExactly*(
@@ -39,7 +36,7 @@ method readLp*(
method write*(
self: MixEntryConnection, msg: seq[byte]
): Future[void] {.async: (raises: [CancelledError, LPStreamError], raw: true), public.} =
self.mixDialer(@msg, self.proto, self.destMultiAddr, self.destPeerId)
self.mixDialer(@msg, self.codec, self.destMultiAddr, self.destPeerId)
proc write*(
self: MixEntryConnection, msg: string
@@ -63,7 +60,7 @@ method writeLp*(
buf[0 ..< vbytes.len] = vbytes.toOpenArray(0, vbytes.len - 1)
buf[vbytes.len ..< buf.len] = msg
self.mixDialer(@buf, self.proto, self.destMultiAddr, self.destPeerId)
self.mixDialer(@buf, self.codec, self.destMultiAddr, self.destPeerId)
method writeLp*(
self: MixEntryConnection, msg: string
@@ -101,7 +98,7 @@ proc new*(
let instance = T(
destMultiAddr: destMultiAddr,
destPeerId: destPeerId,
proto: ProtocolType.fromString(codec),
codec: codec,
mixDialer: mixDialer,
)
@@ -120,13 +117,13 @@ proc new*(
): T {.raises: [].} =
var sendDialerFunc = proc(
msg: seq[byte],
proto: ProtocolType,
codec: string,
destMultiAddr: Opt[MultiAddress],
destPeerId: PeerId,
): Future[void] {.async: (raises: [CancelledError, LPStreamError]).} =
try:
await srcMix.anonymizeLocalProtocolSend(
msg, proto, destMultiAddr, destPeerId, exitNodeIsDestination
msg, codec, destMultiAddr, destPeerId, exitNodeIsDestination
)
except CatchableError as e:
error "Error during execution of anonymizeLocalProtocolSend: ", err = e.msg

View File

@@ -3,7 +3,7 @@
import chronos, chronicles, results
import std/[sequtils, sets]
import libp2p/[multiaddress, protocols/pubsub/pubsubpeer, switch]
import ./[entry_connection, mix_protocol, protocol]
import ./[entry_connection, mix_protocol]
const D* = 4 # No. of peers to forward to

View File

@@ -1,69 +1,85 @@
import chronicles, results
import ./[config, protocol, utils]
import ./[config, utils]
import stew/[byteutils, leb128]
import libp2p/protobuf/minprotobuf
type MixMessage* = object
message: seq[byte]
protocol: ProtocolType
message*: seq[byte]
codec*: string
proc initMixMessage*(message: openArray[byte], protocol: ProtocolType): MixMessage =
return MixMessage(message: @message, protocol: protocol)
proc new*(T: typedesc[MixMessage], message: openArray[byte], codec: string): T =
return T(message: @message, codec: codec)
proc getMixMessage*(mixMsg: MixMessage): (seq[byte], ProtocolType) =
return (mixMsg.message, mixMsg.protocol)
proc serialize*(mixMsg: MixMessage): Result[seq[byte], string] =
if mixMsg.codec.len == 0:
return err("serialization failed: codec cannot be empty")
proc serializeMixMessage*(mixMsg: MixMessage): Result[seq[byte], string] =
try:
let
msgBytes = mixMsg.message
protocolBytes = uint16ToBytes(uint16(mixMsg.protocol))
return ok(msgBytes & protocolBytes)
except Exception as e:
error "Failed to serialize MixMessage", err = e.msg
return err("Serialization failed: " & e.msg)
let vbytes = toBytes(mixMsg.codec.len.uint64, Leb128)
if vbytes.len > 2:
return err("serialization failed: codec length exceeds 2 bytes")
proc deserializeMixMessage*(data: openArray[byte]): Result[MixMessage, string] =
try:
let message = data[0 ..^ (protocolTypeSize + 1)]
var buf =
newSeqUninitialized[byte](vbytes.len + mixMsg.codec.len + mixMsg.message.len)
buf[0 ..< vbytes.len] = vbytes.toOpenArray()
buf[vbytes.len ..< mixMsg.codec.len] = mixMsg.codec.toBytes()
buf[vbytes.len + mixMsg.codec.len ..< buf.len] = mixMsg.message
ok(buf)
let res = bytesToUInt16(data[^protocolTypeSize ..^ 1])
if res.isErr:
return err(res.error)
let protocol = ProtocolType(res.get())
proc deserialize*(
T: typedesc[MixMessage], data: openArray[byte]
): Result[MixMessage, string] =
if data.len == 0:
return err("deserialization failed: data is empty")
return ok(MixMessage(message: message, protocol: protocol))
except Exception as e:
error "Failed to deserialize MixMessage", err = e.msg
return err("Deserialization failed: " & e.msg)
var codecLen: int
var varintLen: int
for i in 0 ..< min(data.len, 2):
let parsed = uint16.fromBytes(data[0 ..< i], Leb128)
if parsed.len < 0 or (i == 1 and parsed.len == 0):
return err("deserialization failed: invalid codec length")
proc serializeMixMessageAndDestination*(
varintLen = parsed.len
codecLen = parsed.val.int
if data.len < varintLen + codecLen:
return err("deserialization failed: not enough data")
ok(
T(
codec: string.fromBytes(data[varintLen ..< varintLen + codecLen]),
message: data[varintLen + codecLen ..< data.len],
)
)
# TODO: These are not used anywhere
# TODO: consider changing the `dest` parameter to a multiaddress
proc serializeWithDestination*(
mixMsg: MixMessage, dest: string
): Result[seq[byte], string] =
try:
let
msgBytes = mixMsg.message
protocolBytes = uint16ToBytes(uint16(mixMsg.protocol))
let destBytes = multiAddrToBytes(dest).valueOr:
return err("Error in multiaddress conversion to bytes: " & error)
let destBytes = multiAddrToBytes(dest).valueOr:
return err("Error in multiaddress conversion to bytes: " & error)
if len(destBytes) != addrSize:
error "Destination address must be exactly " & $addrSize & " bytes"
return err("Destination address must be exactly " & $addrSize & " bytes")
if len(destBytes) != addrSize:
error "Destination address must be exactly " & $addrSize & " bytes"
return err("Destination address must be exactly " & $addrSize & " bytes")
var serializedMixMsg = ?mixMsg.serialize()
let oldLen = serializedMixMsg.len
serializedMixMsg.setLen(oldLen + destBytes.len)
copyMem(addr serializedMixMsg[oldLen], unsafeAddr destBytes[0], destBytes.len)
return ok(msgBytes & protocolBytes & destBytes)
except Exception as e:
error "Failed to serialize MixMessage and destination", err = e.msg
return err("Serialization with destination failed: " & e.msg)
return ok(serializedMixMsg)
proc deserializeMixMessageAndDestination*(
data: openArray[byte]
): Result[(seq[byte], string), string] =
try:
let mixMsg = data[0 ..^ (addrSize + 1)]
# TODO: These are not used anywhere
proc deserializeWithDestination*(
T: typedesc[MixMessage], data: openArray[byte]
): Result[(T, string), string] =
if data.len <= addrSize:
return err("Deserialization with destination failed: not enough data")
let dest = bytesToMultiAddr(data[^addrSize ..^ 1]).valueOr:
return err("Error in destination multiaddress conversion to bytes: " & error)
let mixMsg = ?MixMessage.deserialize(data[0 ..^ (addrSize + 1)])
return ok((mixMsg, dest))
except Exception as e:
return err("Deserialization with destination failed: " & e.msg)
let dest = bytesToMultiAddr(data[^addrSize ..^ 1]).valueOr:
return err("Error in destination multiaddress conversion to bytes: " & error)
return ok((mixMsg, dest))

View File

@@ -1,5 +1,5 @@
import chronicles, chronos, sequtils, strutils, os, results
import std/[strformat, sysrand], serialization, metrics
import std/[strformat, sysrand], metrics
import
./[
config, curve25519, exit_connection, fragmentation, mix_message, mix_node, protocol,
@@ -53,7 +53,7 @@ proc cryptoRandomInt(max: int): Result[int, string] =
return ok(int(value mod uint64(max)))
proc handleMixNodeConnection(
mixProto: MixProtocol, conn: Connection, proto: string
mixProto: MixProtocol, conn: Connection, codec: string
) {.async: (raises: [CancelledError]).} =
var receivedBytes: seq[byte]
try:
@@ -75,10 +75,7 @@ proc handleMixNodeConnection(
let (multiAddr, _, mixPrivKey, _, _) = getMixNodeInfo(mixProto.mixNodeInfo)
let processedPktRes = processSphinxPacket(
receivedBytes,
mixPrivKey,
mixProto.tagManager,
not ProtocolType.fromString(proto).destIsExit,
receivedBytes, mixPrivKey, mixProto.tagManager, not codec.destIsExit
)
if processedPktRes.isErr:
error "Failed to process Sphinx packet", err = processedPktRes.error
@@ -100,18 +97,18 @@ proc handleMixNodeConnection(
mix_messages_error.inc(labelValues = ["Exit", "INVALID_SPHINX"])
return
let deserializedResult = deserializeMixMessage(unpaddedMsg).valueOr:
let deserialized = MixMessage.deserialize(unpaddedMsg).valueOr:
error "Deserialization failed", err = error
mix_messages_error.inc(labelValues = ["Exit", "INVALID_SPHINX"])
return
let (message, protocol) = getMixMessage(deserializedResult)
trace "Exit node - Received mix message: ", receiver = multiAddr, message = message
trace "Exit node - Received mix message: ",
receiver = multiAddr, message = deserialized.message
if destIsExit(protocol):
let exitConn = MixExitConnection.new(message)
if destIsExit(deserialized.codec):
let exitConn = MixExitConnection.new(deserialized.message)
trace "Received: ", receiver = multiAddr, message = message
await mixProto.pHandler(exitConn, protocol)
await mixProto.pHandler(exitConn, deserialized.codec)
if exitConn != nil:
try:
await exitConn.close()
@@ -157,8 +154,8 @@ proc handleMixNodeConnection(
var destConn: Connection
try:
destConn = await mixProto.switch.dial(peerId, @[locationAddr], $protocol)
await destConn.writeLp(message)
destConn = await mixProto.switch.dial(peerId, @[locationAddr], deserialized.codec)
await destConn.writeLp(deserialized.message)
#TODO: When response is implemented, we can read the response here
await destConn.close()
except CatchableError as e:
@@ -222,17 +219,23 @@ proc handleMixNodeConnection(
mix_messages_error.inc(labelValues = ["Intermediate/Exit", "INVALID_MAC"])
discard
proc getMaxMessageSizeForCodec*(codec: string): Result[int, string] =
let serializedMsg = ?MixMessage.new(@[], codec).serialize()
if serializedMsg.len > dataSize:
return err("cannot encode messages for this codec")
return ok(dataSize - serializedMsg.len)
proc anonymizeLocalProtocolSend*(
mixProto: MixProtocol,
msg: seq[byte],
proto: ProtocolType,
codec: string,
destMultiAddr: Opt[MultiAddress],
destPeerId: PeerId,
exitNodeIsDestination: bool,
) {.async.} =
let mixMsg = initMixMessage(msg, proto)
let mixMsg = MixMessage.new(msg, codec)
let serialized = serializeMixMessage(mixMsg).valueOr:
let serialized = mixMsg.serialize().valueOr:
error "Serialization failed", err = error
mix_messages_error.inc(labelValues = ["Entry", "NON_RECOVERABLE"])
return
@@ -414,10 +417,10 @@ proc new*(
return err("Failed to load mix pub info for index " & $index & " - err: " & error)
var sendHandlerFunc = proc(
conn: Connection, proto: ProtocolType
conn: Connection, codec: string
): Future[void] {.async: (raises: [CancelledError]).} =
try:
await callHandler(switch, conn, proto)
await callHandler(switch, conn, codec)
except CatchableError as e:
error "Error during execution of MixProtocol handler: ", err = e.msg
return

View File

@@ -1,57 +1,22 @@
import chronos, std/enumerate, strutils
import chronos, std/enumerate
import
libp2p/[builders, protocols/ping, protocols/pubsub/gossipsub/types, stream/connection]
import
../examples/protocols/noresp_ping
# TODO: remove this in PR that makes it not necessary to have a ProtocolType enum
const protocolTypeSize* = 2
type ProtocolType* = enum
Ping
GossipSub12
GossipSub11
GossipSub10
NoRespPing
WakuLightPushProtocol
OtherProtocol
proc `$`*(proto: ProtocolType): string =
case proto
of Ping:
PingCodec
of GossipSub12:
GossipSubCodec_12
of GossipSub11:
GossipSubCodec_11
of GossipSub10:
GossipSubCodec_10
of NoRespPing:
NoRespPingCodec
of WakuLightPushProtocol:
"/vac/waku/lightpush/3.0.0"
#TODO: fix this hardcoding, for now doing it as importing codecs from waku causses various build errors.
else:
"other" # Placeholder for other protocols
type ProtocolHandler* = proc(conn: Connection, proto: ProtocolType): Future[void] {.
type ProtocolHandler* = proc(conn: Connection, codec: string): Future[void] {.
async: (raises: [CancelledError])
.}
proc fromString*(T: type ProtocolType, proto: string): ProtocolType =
try:
parseEnum[ProtocolType](proto)
except ValueError:
ProtocolType.OtherProtocol
# TODO: this is temporary while I attempt to extract protocol specific logic from mix
func destIsExit*(proto: ProtocolType): bool =
return not (proto == GossipSub12 or proto == GossipSub11 or proto == GossipSub10)
func destIsExit*(proto: string): bool =
return
not (
proto == GossipSubCodec_12 or proto == GossipSubCodec_11 or
proto == GossipSubCodec_10
)
method callHandler*(
switch: Switch, conn: Connection, proto: ProtocolType
switch: Switch, conn: Connection, codec: string
): Future[void] {.base, async.} =
let codec = $proto
for index, handler in enumerate(switch.ms.handlers):
if codec in handler.protos:
await handler.protocol.handler(conn, codec)

View File

@@ -1,100 +1,95 @@
{.used.}
import chronicles, results, unittest
import ../mix/[mix_message, protocol]
import ../mix/mix_message
import stew/byteutils
# Define test cases
suite "mix_message_tests":
test "serialize_and_deserialize_mix_message":
let
message = "Hello World!"
protocol = ProtocolType.Ping
mixMsg = initMixMessage(cast[seq[byte]](message), protocol)
codec = "/test/codec/1.0.0"
mixMsg = MixMessage.new(message.toBytes(), codec)
let serializedResult = serializeMixMessage(mixMsg)
let serializedResult = mixMsg.serialize()
if serializedResult.isErr:
error "Serialization failed", err = serializedResult.error
fail()
let serialized = serializedResult.get()
let deserializedResult = deserializeMixMessage(serialized)
let deserializedResult = MixMessage.deserialize(serialized)
if deserializedResult.isErr:
error "Deserialization failed", err = deserializedResult.error
fail()
let deserializedMsg = deserializedResult.get()
let (dMessage, dProtocol) = getMixMessage(deserializedMsg)
if message != cast[string](dMessage):
if message != string.fromBytes(deserializedMsg.message):
error "Deserialized message does not match the original",
original = message, deserialized = cast[string](dMessage)
original = message, deserialized = string.fromBytes(deserializedMsg.message)
fail()
if protocol != dProtocol:
error "Deserialized protocol does not match the original",
original = protocol, deserialized = dProtocol
if codec != deserializedMsg.codec:
error "Deserialized codec does not match the original",
original = codec,
deserialized = deserializedMsg.codec,
codeco = cast[seq[byte]](codec),
codeder = cast[seq[byte]](deserializedMsg.codec)
fail()
test "serialize_empty_mix_message":
let
emptyMessage = ""
protocol = ProtocolType.OtherProtocol
mixMsg = initMixMessage(cast[seq[byte]](emptyMessage), protocol)
codec = "/test/codec/1.0.0"
mixMsg = MixMessage.new(emptyMessage.toBytes(), codec)
let serializedResult = serializeMixMessage(mixMsg)
let serializedResult = mixMsg.serialize()
if serializedResult.isErr:
error "Serialization failed", err = serializedResult.error
fail()
let serialized = serializedResult.get()
let deserializedResult = deserializeMixMessage(serialized)
let deserializedResult = MixMessage.deserialize(serialized)
if deserializedResult.isErr:
error "Deserialization failed", err = deserializedResult.error
fail()
let dMixMsg: MixMessage = deserializedResult.get()
let (dMessage, dProtocol) = getMixMessage(dMixMsg)
if emptyMessage != cast[string](dMessage):
if emptyMessage != string.fromBytes(dMixMsg.message):
error "Deserialized message is not empty",
expected = emptyMessage, actual = cast[string](dMessage)
expected = emptyMessage, actual = string.fromBytes(dMixMsg.message)
fail()
if protocol != dProtocol:
error "Deserialized protocol does not match the original",
original = protocol, deserialized = dProtocol
if codec != dMixMsg.codec:
error "Deserialized codec does not match the original",
original = codec, deserialized = dMixMsg.codec
fail()
test "serialize_and_deserialize_mix_message_and_destination":
let
message = "Hello World!"
protocol = ProtocolType.GossipSub12
codec = "/test/codec/1.0.0"
destination =
"/ip4/0.0.0.0/tcp/4242/p2p/16Uiu2HAmFkwLVsVh6gGPmSm9R3X4scJ5thVdKfWYeJsKeVrbcgVC"
mixMsg = initMixMessage(cast[seq[byte]](message), protocol)
mixMsg = MixMessage.new(message.toBytes(), codec)
let serializedResult = serializeMixMessageAndDestination(mixMsg, destination)
let serializedResult = mixMsg.serializeWithDestination(destination)
if serializedResult.isErr:
error "Serialization with destination failed", err = serializedResult.error
fail()
let serialized = serializedResult.get()
let deserializedResult = deserializeMixMessageAndDestination(serialized)
let deserializedResult = MixMessage.deserializeWithDestination(serialized)
if deserializedResult.isErr:
error "Deserialization with destination failed", err = deserializedResult.error
fail()
let (mixMsgBytes, dDest) = deserializedResult.get()
let (dMixMessage, dDest) = deserializedResult.get()
let dMixMsgResult = deserializeMixMessage(mixMsgBytes)
if dMixMsgResult.isErr:
error "Deserialization of MixMessage failed", err = dMixMsgResult.error
fail()
let dMixMsg = dMixMsgResult.get()
let (dMessage, dProtocol) = getMixMessage(dMixMsg)
if message != cast[string](dMessage):
if message != string.fromBytes(dMixMessage.message):
error "Deserialized message does not match the original",
original = message, deserialized = cast[string](dMessage)
original = message, deserialized = string.fromBytes(dMixMessage.message)
fail()
if protocol != dProtocol:
error "Deserialized protocol does not match the original",
original = protocol, deserialized = dProtocol
if codec != dMixMessage.codec:
error "Deserialized codec does not match the original",
original = codec, deserialized = dMixMessage.codec
fail()
if destination != dDest:
error "Deserialized destination does not match the original",