From a19f515fe558cb8ed33a064fbfab74f4a77b59ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Seshanth=2ES=F0=9F=90=BA?= <35675963+seshanthS@users.noreply.github.com> Date: Thu, 20 Feb 2025 07:07:46 +0530 Subject: [PATCH] Hotfix/audit rsapss (#144) --- .../crypto/signature/rsapss/rsapss3.circom | 5 + .../signature/rsapss/rsapss65537.circom | 5 + .../crypto/signature/rsapss/validate.circom | 45 ++++ .../tests/utils/generateMockInputsRsaPss.ts | 47 ++++ circuits/tests/utils/rsapss.test.ts | 151 +++++++++++-- circuits/tests/utils/testcase/rsapss.ts | 211 ++++++++++++++++++ 6 files changed, 444 insertions(+), 20 deletions(-) create mode 100644 circuits/circuits/utils/crypto/signature/rsapss/validate.circom create mode 100644 circuits/tests/utils/testcase/rsapss.ts diff --git a/circuits/circuits/utils/crypto/signature/rsapss/rsapss3.circom b/circuits/circuits/utils/crypto/signature/rsapss/rsapss3.circom index 27ad5ac3d..d4a02ce1c 100644 --- a/circuits/circuits/utils/crypto/signature/rsapss/rsapss3.circom +++ b/circuits/circuits/utils/crypto/signature/rsapss/rsapss3.circom @@ -5,6 +5,7 @@ include "./mgf1.circom"; include "../../bitify/gates.circom"; include "../../hasher/hash.circom"; include "../FpPowMod.circom"; +include "./validate.circom"; /* * RSA-PSS (Probabilistic Signature Scheme) Signature Verification @@ -70,6 +71,10 @@ template VerifyRsaPss3Sig(CHUNK_SIZE, CHUNK_NUMBER, SALT_LEN, HASH_TYPE, KEY_LEN signal eM[EM_LEN]; signal eMsgInBits[EM_LEN_BITS]; + + component validateRsaPss = ValidateRsaPss(CHUNK_SIZE, CHUNK_NUMBER, KEY_LENGTH); + validateRsaPss.pubkey <== pubkey; + validateRsaPss.signature <== signature; //computing encoded message component bigPow = FpPow3Mod(CHUNK_SIZE, CHUNK_NUMBER); diff --git a/circuits/circuits/utils/crypto/signature/rsapss/rsapss65537.circom b/circuits/circuits/utils/crypto/signature/rsapss/rsapss65537.circom index 67e61b03e..304650782 100644 --- a/circuits/circuits/utils/crypto/signature/rsapss/rsapss65537.circom +++ b/circuits/circuits/utils/crypto/signature/rsapss/rsapss65537.circom @@ -5,6 +5,7 @@ include "./mgf1.circom"; include "../../bitify/gates.circom"; include "../../hasher/hash.circom"; include "../FpPowMod.circom"; +include "./validate.circom"; /* * RSA-PSS (Probabilistic Signature Scheme) Signature Verification @@ -73,6 +74,10 @@ template VerifyRsaPss65537Sig(CHUNK_SIZE, CHUNK_NUMBER, SALT_LEN, HASH_TYPE, KEY signal eM[EM_LEN]; signal eMsgInBits[EM_LEN_BITS]; + component validateRsaPss = ValidateRsaPss(CHUNK_SIZE, CHUNK_NUMBER, KEY_LENGTH); + validateRsaPss.pubkey <== pubkey; + validateRsaPss.signature <== signature; + //computing encoded message component bigPow = FpPow65537Mod(CHUNK_SIZE, CHUNK_NUMBER); for (var i = 0; i < CHUNK_NUMBER; i++) { diff --git a/circuits/circuits/utils/crypto/signature/rsapss/validate.circom b/circuits/circuits/utils/crypto/signature/rsapss/validate.circom new file mode 100644 index 000000000..570329810 --- /dev/null +++ b/circuits/circuits/utils/crypto/signature/rsapss/validate.circom @@ -0,0 +1,45 @@ +pragma circom 2.1.6; + +include "../../bigInt/bigInt.circom"; +include "@openpassport/zk-email-circuits/lib/bigint.circom"; + +/// @notice Validates the RSA-PSS signature format +/// @dev Checks that the signature and public key are within the modulus length. +/// @param CHUNK_SIZE Size of each chunk in bits +/// @param CHUNK_NUMBER Number of chunks in modulus +/// @param KEY_LENGTH RSA key length (modulus length) in bits +template ValidateRsaPss(CHUNK_SIZE, CHUNK_NUMBER, KEY_LENGTH) { + signal input pubkey[CHUNK_NUMBER]; + signal input signature[CHUNK_NUMBER]; + + var fullChunks = KEY_LENGTH \ CHUNK_SIZE; + var remainingBits = KEY_LENGTH % CHUNK_SIZE; + + component sigBitChecks[CHUNK_NUMBER]; + component pubkeyBitChecks[CHUNK_NUMBER]; + + // Check value in each chunk can be represented in CHUNK_SIZE bits + for (var i = 0; i < fullChunks; i++) { + sigBitChecks[i] = Num2Bits(CHUNK_SIZE); + pubkeyBitChecks[i] = Num2Bits(CHUNK_SIZE); + sigBitChecks[i].in <== signature[i]; + pubkeyBitChecks[i].in <== pubkey[i]; + } + if (remainingBits > 0) { + sigBitChecks[fullChunks] = Num2Bits(remainingBits); + pubkeyBitChecks[fullChunks] = Num2Bits(remainingBits); + sigBitChecks[fullChunks].in <== signature[fullChunks]; + pubkeyBitChecks[fullChunks].in <== pubkey[fullChunks]; + } + //zero padding for remaining chunks + for(var i = fullChunks + 1; i < CHUNK_NUMBER; i++) { + signature[i] === 0; + pubkey[i] === 0; + } + + //signature cannot exceed public key modulus + component bigLessThan = BigLessThan(CHUNK_SIZE, CHUNK_NUMBER); + bigLessThan.a <== signature; + bigLessThan.b <== pubkey; + bigLessThan.out === 1; +} \ No newline at end of file diff --git a/circuits/tests/utils/generateMockInputsRsaPss.ts b/circuits/tests/utils/generateMockInputsRsaPss.ts index 827ebd36e..17739dcee 100644 --- a/circuits/tests/utils/generateMockInputsRsaPss.ts +++ b/circuits/tests/utils/generateMockInputsRsaPss.ts @@ -48,5 +48,52 @@ export const generateMockRsaPssInputs = ( signature: splitToWords(BigInt(bytesToBigDecimal(signature)), n, k), modulus: splitToWords(BigInt(hexToDecimal(modulus)), n, k), message: messageBits, + n, + k, + }; +}; + +export const generateMalleableRsaPssInputs = ( + signatureAlgorithm: SignatureAlgorithm, + saltLength: number +) => { + const [sigAlg, hashAlgorithm, exponent, modulusLength] = signatureAlgorithm.split('_'); + + // Generate RSA key pair + const keypair = forge.pki.rsa.generateKeyPair({ + bits: parseInt(modulusLength), + e: parseInt(exponent), + }); + + const message = 'helloworld'; + const md = forge.md[hashAlgorithm].create(); + md.update(forge.util.binary.raw.encode(Buffer.from(message))); + const messageHash = md.digest().bytes(); + + // Create valid signature + const pss = forge.pss.create({ + md: forge.md[hashAlgorithm].create(), + mgf: forge.mgf.mgf1.create(forge.md[hashAlgorithm].create()), + saltLength, + }); + + const signatureBytes = keypair.privateKey.sign(md, pss); + const signature = Array.from(signatureBytes, (c: string) => c.charCodeAt(0)); + + const modulus = BigInt('0x' + keypair.publicKey.n.toString(16)); + const sigValue = BigInt(bytesToBigDecimal(signature)); + const malleableValue = sigValue + modulus; + + const { n, k } = getNAndK(signatureAlgorithm); + + return { + signature: splitToWords(malleableValue, n, k), + modulus: splitToWords(modulus, n, k), + message: Array.from(messageHash) + .map((char: string) => { + const byte = char.charCodeAt(0); + return Array.from({ length: 8 }, (_, i) => (byte >> (7 - i)) & 1); + }) + .flat(), }; }; diff --git a/circuits/tests/utils/rsapss.test.ts b/circuits/tests/utils/rsapss.test.ts index 5f05d9179..31428d25d 100644 --- a/circuits/tests/utils/rsapss.test.ts +++ b/circuits/tests/utils/rsapss.test.ts @@ -1,34 +1,55 @@ import { wasm as wasmTester } from 'circom_tester'; import { describe, it } from 'mocha'; import path from 'path'; -import { SignatureAlgorithm } from '../../../common/src/utils/types'; -import { generateMockRsaPssInputs } from './generateMockInputsRsaPss'; +import { + generateMalleableRsaPssInputs, + generateMockRsaPssInputs, +} from './generateMockInputsRsaPss'; import { expect } from 'chai'; +import { fullAlgorithms, sigAlgs, AdditionalCases } from './testcase/rsapss'; describe('VerifyRsapss Circuit Test', function () { this.timeout(0); - const fullAlgorithms: { algo: SignatureAlgorithm; saltLength: number }[] = [ - { algo: 'rsapss_sha256_65537_4096', saltLength: 64 }, - { algo: 'rsapss_sha256_65537_3072', saltLength: 64 }, - { algo: 'rsapss_sha256_65537_2048', saltLength: 64 }, - { algo: 'rsapss_sha256_3_4096', saltLength: 64 }, - { algo: 'rsapss_sha256_3_3072', saltLength: 64 }, - { algo: 'rsapss_sha256_3_2048', saltLength: 64 }, - { algo: 'rsapss_sha512_3_4096', saltLength: 64 }, - { algo: 'rsapss_sha512_3_2048', saltLength: 64 }, - { algo: 'rsapss_sha384_65537_4096', saltLength: 48 }, - { algo: 'rsapss_sha384_65537_3072', saltLength: 48 }, - { algo: 'rsapss_sha384_3_4096', saltLength: 48 }, - { algo: 'rsapss_sha384_3_3072', saltLength: 48 }, - ]; - - const sigAlgs: { algo: SignatureAlgorithm; saltLength: number }[] = [ - { algo: 'rsapss_sha256_65537_4096', saltLength: 32 }, - ]; const testSuite = process.env.FULL_TEST_SUITE === 'true' ? fullAlgorithms : sigAlgs; testSuite.forEach((algorithm) => { + AdditionalCases[algorithm.algo]?.forEach((additionalCase) => { + it(`${additionalCase.title} for ${algorithm.algo}_${algorithm.saltLength} with additional case`, async function () { + this.timeout(0); + const signature = additionalCase.signature; + const modulus = additionalCase.modulus; + const message = additionalCase.message; + + const circuit = await wasmTester( + path.join( + __dirname, + `../../circuits/tests/utils/rsapss/test_${algorithm.algo}_${algorithm.saltLength}.circom` + ), + { + include: ['node_modules', './node_modules/@zk-kit/binary-merkle-root.circom/src'], + } + ); + + try { + const witness = await circuit.calculateWitness({ + signature, + modulus, + message, + }); + + // Check constraints + await circuit.checkConstraints(witness); + } catch (error) { + if (additionalCase.shouldFail) { + expect(error.message).to.include('Assert Failed'); + } else { + throw error; + } + } + }); + }); + it(`should verify RSA-PSS signature using the circuit for ${algorithm.algo}_${algorithm.saltLength}`, async function () { this.timeout(0); // Generate inputs using the utility function @@ -123,5 +144,95 @@ describe('VerifyRsapss Circuit Test', function () { expect(error.message).to.include('Assert Failed'); } }); + + it('Should reject signatures greater than or equal to modulus', async function () { + const { signature, modulus, message, n, k } = generateMockRsaPssInputs( + algorithm.algo, + algorithm.saltLength + ); + + const largeSignature = [...signature]; + largeSignature[k - 1] = String(BigInt(modulus[k - 1]) + 1n); + + const circuit = await wasmTester( + path.join( + __dirname, + `../../circuits/tests/utils/rsapss/test_${algorithm.algo}_${algorithm.saltLength}.circom` + ), + { + include: ['node_modules', './node_modules/@zk-kit/binary-merkle-root.circom/src'], + } + ); + + try { + await circuit.calculateWitness({ + signature: largeSignature, + modulus, + message, + }); + throw new Error('Circuit accepted signature >= modulus'); + } catch (error) { + expect(error.message).to.include('Assert Failed'); + } + }); + + it('Should reject malleable signatures (signature + modulus)', async function () { + const { signature, modulus, message } = generateMalleableRsaPssInputs( + algorithm.algo, + algorithm.saltLength + ); + + const circuit = await wasmTester( + path.join( + __dirname, + `../../circuits/tests/utils/rsapss/test_${algorithm.algo}_${algorithm.saltLength}.circom` + ), + { + include: ['node_modules', './node_modules/@zk-kit/binary-merkle-root.circom/src'], + } + ); + + try { + await circuit.calculateWitness({ + signature, + modulus, + message, + }); + throw new Error('Circuit accepted malleable signature'); + } catch (error) { + expect(error.message).to.include('Assert Failed'); + } + }); + + it('Should Fails when chunk has more bits than n', async function () { + const { signature, modulus, message } = generateMockRsaPssInputs( + algorithm.algo, + algorithm.saltLength + ); + + let overflowSignature = [...signature]; + overflowSignature[0] = String(BigInt(2) ** BigInt(122)); + + const circuit = await wasmTester( + path.join( + __dirname, + `../../circuits/tests/utils/rsapss/test_${algorithm.algo}_${algorithm.saltLength}.circom` + ), + { + include: ['node_modules', './node_modules/@zk-kit/binary-merkle-root.circom/src'], + } + ); + + try { + await circuit.calculateWitness({ + signature: overflowSignature, + modulus, + message, + }); + throw new Error('Circuit accepted malleable signature'); + } catch (error) { + expect(error.message).to.include('Assert Failed'); + } + }); }); }); diff --git a/circuits/tests/utils/testcase/rsapss.ts b/circuits/tests/utils/testcase/rsapss.ts new file mode 100644 index 000000000..df113e20a --- /dev/null +++ b/circuits/tests/utils/testcase/rsapss.ts @@ -0,0 +1,211 @@ +import { SignatureAlgorithm } from '../../../../common/src/utils/types'; + +export const fullAlgorithms: { algo: SignatureAlgorithm; saltLength: number }[] = [ + { algo: 'rsapss_sha256_65537_4096', saltLength: 32 }, + { algo: 'rsapss_sha256_65537_3072', saltLength: 32 }, + { algo: 'rsapss_sha256_65537_2048', saltLength: 32 }, + { algo: 'rsapss_sha256_65537_4096', saltLength: 64 }, + { algo: 'rsapss_sha256_65537_3072', saltLength: 64 }, + { algo: 'rsapss_sha256_65537_2048', saltLength: 64 }, + { algo: 'rsapss_sha256_3_4096', saltLength: 32 }, + { algo: 'rsapss_sha256_3_3072', saltLength: 32 }, + { algo: 'rsapss_sha256_3_2048', saltLength: 32 }, + { algo: 'rsapss_sha256_3_4096', saltLength: 64 }, + { algo: 'rsapss_sha256_3_3072', saltLength: 64 }, + { algo: 'rsapss_sha256_3_2048', saltLength: 64 }, + { algo: 'rsapss_sha512_3_4096', saltLength: 64 }, + { algo: 'rsapss_sha512_3_2048', saltLength: 64 }, + { algo: 'rsapss_sha384_65537_4096', saltLength: 48 }, + { algo: 'rsapss_sha384_65537_3072', saltLength: 48 }, + { algo: 'rsapss_sha384_3_4096', saltLength: 48 }, + { algo: 'rsapss_sha384_3_3072', saltLength: 48 }, +]; + +export const sigAlgs: { algo: SignatureAlgorithm; saltLength: number }[] = [ + { algo: 'rsapss_sha256_65537_2048', saltLength: 64 }, + { algo: 'rsapss_sha256_3_2048', saltLength: 64 }, +]; + +export const AdditionalCases = { + rsapss_sha256_3_2048: [ + { + title: 'Should fail - rsa_pss_signature_plus_modulus', + shouldFail: true, + signature: [ + '448593004166146035435698857233297481', + '645183938101488623213176183413849323', + '666034718174307596971224365157466704', + '1271375153014817405982734151086452096', + '1180028692934487107976994982225874491', + '43804972648824013081797601863424227', + '242516006679401681063343065419167823', + '1182227289007376475828724611303448394', + '901331680040946404805569866222078428', + '1112202048195884255403471214134025563', + '141640114810700046171697631501252046', + '1319213481608985222605784359236670013', + '1291596147188972332688115034547946543', + '1255959259511537862873460606790747882', + '1179627418594555107027925239956205308', + '1278704611743277638652199540388493869', + '286500237759217977031993526159467836', + '298', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + ], + modulus: [ + '424995376221342323916068129274876411', + '67032731905523875217374208431095706', + '1088902097090709882045251159028487271', + '905520378129685131655683371940143205', + '772653984210635001119928127332075925', + '453720758617437459118995325060964811', + '682423291494784049924695549058692937', + '787266768752181208854233249088291469', + '186472877058417959779133158879134681', + '464686717810020849123595839870153012', + '172023685318598391521750104755663284', + '515549827144062588662785848942123249', + '558628673475731478847640300679590539', + '92387777511544778109146896322836204', + '1061966765604557646220482768986492741', + '551767739591461330436511362339519451', + '394444902710212854489591990203694934', + '221', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + ], + message: [ + 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, + 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, + 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, + 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, + 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, + 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, + 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, + 1, 0, 1, 0, 0, 0, 1, 0, + ], + }, + + { + title: 'Should fail - rsa_pss_invalid_chunk_size', + shouldFail: true, + signature: [ + '7731115078616803665798114145181535098', + '163203956634050319114380978254630080', + '346148465904621811693353377067303234', + '64964293608834970395530903389524729', + '335420826842907861754475597205951544', + '925528679415886357007293672250413399', + '839020387053516695975203231151634658', + '778568500012672705243235513963101764', + '1237818354950660153047639164514793540', + '1279126110528363299554438433474920526', + '1271457987192236182994923269736835892', + '988175487321830418383164411072850446', + '887908374708917746102974749377227834', + '209746796442816091240476041947465923', + '1304363295406083456667488431125879054', + '524718671036833927121760315388886895', + '753317262042764462315754015065545818', + '121', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + ], + modulus: [ + '466890184860597456401263469298030355', + '934892081549427649323807927922900130', + '208645761278164609062745855133745252', + '706235177207591471475050602074473379', + '292598569458658692052640890836570796', + '495350572662119637520831264600583776', + '1195022894548414203114969601950177119', + '332800708856449316235710513302701152', + '45000490396361335068087327257520852', + '111185801182984893004870674922017989', + '421368615499870239882744392415817086', + '827966678350734062885488434550136996', + '77942854540484756138148180622094839', + '1043984843798011423344204006579183000', + '1118402313189367453676081212715001540', + '550698616548662574616349662954124741', + '904303165189377290183430808673064367', + '188', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + '0', + ], + message: [ + 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, + 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, + 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, + 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, + 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, + 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, + 1, 0, 1, 0, 1, 1, 1, 1, + ], + }, + ], +};