diff --git a/mix/sphinx.nim b/mix/sphinx.nim index bb4c382..fdb439b 100644 --- a/mix/sphinx.nim +++ b/mix/sphinx.nim @@ -1,11 +1,12 @@ import results, sequtils -import std/math +import nimcrypto/sysrand import ./[config, crypto, curve25519, serialization, tag_manager] # Define possible outcomes of processing a Sphinx packet type ProcessingStatus* = enum Exit # Packet processed successfully at exit Success # Packet processed successfully + Reply # Packet processed successfully at exit; a reply message Duplicate # Packet was discarded due to duplicate tag InvalidMAC # Packet was discarded due to MAC verification failure @@ -112,34 +113,26 @@ proc generateRandomDelay(): seq[byte] = return toseq(delayBytes) ]# -# Function to compute betas, gammas, and deltas -proc computeBetaGammaDelta( - s: seq[seq[byte]], - hop: openArray[Hop], - msg: Message, - delay: openArray[seq[byte]], - destHop: Hop, -): Result[(seq[byte], seq[byte], seq[byte]), string] = +# Function to compute betas, and gammas +proc computeBetaGamma( + s: seq[seq[byte]], hop: openArray[Hop], delay: openArray[seq[byte]], destHop: Hop +): Result[(seq[byte], seq[byte]), string] = let sLen = s.len var beta: seq[byte] gamma: seq[byte] - delta: seq[byte] # Compute filler strings let filler = computeFillerStrings(s).valueOr: return err("Error in filler generation: " & error) for i in countdown(sLen - 1, 0): - # Derive AES keys, MAC key, and IVs + # Derive AES key, MAC key, and IV let beta_aes_key = kdf(deriveKeyMaterial("aes_key", s[i])) mac_key = kdf(deriveKeyMaterial("mac_key", s[i])) beta_iv = kdf(deriveKeyMaterial("iv", s[i])) - delta_aes_key = kdf(deriveKeyMaterial("delta_aes_key", s[i])) - delta_iv = kdf(deriveKeyMaterial("delta_iv", s[i])) - # Compute Beta and Gamma if i == sLen - 1: let destBytes = serializeHop(destHop).valueOr: @@ -151,14 +144,6 @@ proc computeBetaGammaDelta( let aesRes = aes_ctr(beta_aes_key, beta_iv, destPadding).valueOr: return err("Error in aes: " & error) beta = aesRes & filler - - let serializeRes = serializeMessage(msg).valueOr: - return err("Message serialization error: " & error) - - let deltaRes = aes_ctr(delta_aes_key, delta_iv, serializeRes) - if deltaRes.isErr: - return err("Error in aes: " & deltaRes.error) - delta = deltaRes.get() else: let routingInfo = initRoutingInfo( hop[i + 1], delay[i + 1], gamma, beta[0 .. (((r * (t + 1)) - t) * k) - 1] @@ -172,14 +157,111 @@ proc computeBetaGammaDelta( return err("Error in aes: " & betaRes.error) beta = betaRes.get() + gamma = toseq(hmac(mac_key, beta)) + + return ok((beta, gamma)) + +# Function to compute deltas +proc computeDelta(s: seq[seq[byte]], msg: Message): Result[seq[byte], string] = + let sLen = s.len + var delta: seq[byte] + + for i in countdown(sLen - 1, 0): + # Derive AES key and IV + let + delta_aes_key = kdf(deriveKeyMaterial("delta_aes_key", s[i])) + delta_iv = kdf(deriveKeyMaterial("delta_iv", s[i])) + + # Compute Delta + if i == sLen - 1: + let serializeRes = serializeMessage(msg).valueOr: + return err("Message serialization error: " & error) + + let deltaRes = aes_ctr(delta_aes_key, delta_iv, serializeRes) + if deltaRes.isErr: + return err("Error in aes: " & deltaRes.error) + delta = deltaRes.get() + else: let deltaRes = aes_ctr(delta_aes_key, delta_iv, delta) if deltaRes.isErr: return err("Error in aes: " & deltaRes.error) delta = deltaRes.get() - gamma = toseq(hmac(mac_key, beta)) + return ok(delta) - return ok((beta, gamma, delta)) +proc createSURB*( + publicKeys: openArray[FieldElement], + delay: openArray[seq[byte]], + hop: openArray[Hop], + destHop: Hop, +): Result[(Hop, Header, seq[seq[byte]], seq[byte]), string] = + # Compute alpha and shared secrets + let res1 = computeAlpha(publicKeys) + if res1.isErr: + return err("Error in alpha generation: " & res1.error) + let (alpha_0, s) = res1.get() + + # Compute beta and gamma + let res2 = computeBetaGamma(s, hop, delay, destHop) + if res2.isErr: + return err("Error in beta and gamma generation: " & res2.error) + let (beta_0, gamma_0) = res2.get() + + # Generate key + var key = newSeq[byte](k) + discard randomBytes(key) + + return ok((hop[0], initHeader(alpha_0, beta_0, gamma_0), s, key)) + +proc useSURB*(header: Header, key: seq[byte], msg: Message): Result[seq[byte], string] = + # Derive AES key and IV + let + delta_aes_key = kdf(deriveKeyMaterial("delta_aes_key", key)) + delta_iv = kdf(deriveKeyMaterial("delta_iv", key)) + + # Compute Delta + let serializeMsg = serializeMessage(msg).valueOr: + return err("Message serialization error: " & error) + + let delta = aes_ctr(delta_aes_key, delta_iv, serializeMsg).valueOr: + return err("Error in aes: " & error) + + # Serialize sphinx packet + let sphinxPacket = initSphinxPacket(header, delta) + + let serializeRes = serializeSphinxPacket(sphinxPacket).valueOr: + return err("Sphinx packet serialization error: " & error) + + return ok(serializeRes) + +proc processReply*( + key: seq[byte], s: seq[seq[byte]], delta_prime: seq[byte] +): Result[seq[byte], string] = + var delta = delta_prime[0 ..^ 1] + + for i in 0 .. s.len: + var key_prime: seq[byte] + + if i == 0: + key_prime = key + else: + key_prime = s[i - 1] + + # Derive AES key and IV + let + delta_aes_key = kdf(deriveKeyMaterial("delta_aes_key", key_prime)) + delta_iv = kdf(deriveKeyMaterial("delta_iv", key_prime)) + + let deltaRes = aes_ctr(delta_aes_key, delta_iv, delta) + if deltaRes.isErr: + return err("Error in aes: " & deltaRes.error) + delta = deltaRes.get() + + let deserializeRes = deserializeMessage(delta).valueOr: + return err("Message deserialization error: " & error) + let msg = getMessage(deserializeRes) + + return ok(msg) proc wrapInSphinxPacket*( msg: Message, @@ -188,17 +270,21 @@ proc wrapInSphinxPacket*( hop: openArray[Hop], destHop: Hop, ): Result[seq[byte], string] = - # Compute alphas and shared secrets + # Compute alpha and shared secrets let res1 = computeAlpha(publicKeys) if res1.isErr: return err("Error in alpha generation: " & res1.error) let (alpha_0, s) = res1.get() - # Compute betas, gammas, and deltas - let res2 = computeBetaGammaDelta(s, hop, msg, delay, destHop) + # Compute beta and gamma + let res2 = computeBetaGamma(s, hop, delay, destHop) if res2.isErr: - return err("Error in beta, gamma, and delta generation: " & res2.error) - let (beta_0, gamma_0, delta_0) = res2.get() + return err("Error in beta and gamma generation: " & res2.error) + let (beta_0, gamma_0) = res2.get() + + # Compute delta + let delta_0 = computeDelta(s, msg).valueOr: + return err("Error in delta generation: " & error) # Serialize sphinx packet let sphinxPacket = initSphinxPacket(initHeader(alpha_0, beta_0, gamma_0), delta_0) @@ -268,13 +354,17 @@ proc processSphinxPacket*( zeroPadding = newSeq[byte](paddingLength) if B[(t * k) .. (t * k) + paddingLength - 1] == zeroPadding: - let deserializeRes = deserializeMessage(delta_prime).valueOr: - return err("Message deserialization error: " & error) - let msg = getMessage(deserializeRes) - let hop = deserializeHop(B[0 .. addrSize - 1]).valueOr: return err(error) - return ok((hop, B[addrSize .. ((t * k) - 1)], msg[0 .. messageSize - 1], Exit)) + + if delta_prime[0 .. (k - 1)] == newSeq[byte](k): + let deserializeRes = deserializeMessage(delta_prime).valueOr: + return err("Message deserialization error: " & error) + let msg = getMessage(deserializeRes) + + return ok((hop, B[addrSize .. ((t * k) - 1)], msg[0 .. messageSize - 1], Exit)) + else: + return ok((hop, B[addrSize .. ((t * k) - 1)], delta_prime, Reply)) else: # Extract routing information from B let deserializeRes = deserializeRoutingInfo(B).valueOr: