Files
fhevm-solidity/codegen/generateOverloads.ts
2025-04-10 11:42:06 +02:00

407 lines
14 KiB
TypeScript

import { FheType } from './common';
import { findMinimumValueInBigIntArray, generateRandomNumber } from './utils';
/**
* Represents a test structure with input and output types.
*
* @property inputs - An array of `bigint` values representing the inputs.
* @property output - The output value, which can be a `number`, `boolean`, or `bigint`.
*/
type Test = {
inputs: bigint[];
output: number | boolean | bigint;
};
/**
* Represents a collection of supported functions, where each function is identified
* by a unique key and associated with its corresponding `SupportedFunction` definition.
*/
type SupportedFunctions = {
[key: string]: SupportedFunction;
};
/**
* Represents the parameters that can be used to configure supported functions.
*
* @property safeMin - Optional. If `true`, ensures that the minimum value is safely handled.
* @property noScalar - Optional. If `true`, disables scalar operations.
* @property lhsHigher - Optional. If `true`, enforces that the left-hand side value is higher.
* @property scalarOnly - Optional. If `true`, restricts operations to scalar values only.
* @property limit - Optional. Specifies a limit type, such as `'bits'`, to constrain operations.
*/
type SupportedFunctionParams = {
safeMin?: boolean;
noScalar?: boolean;
lhsHigher?: boolean;
scalarOnly?: boolean;
limit?: 'bits';
};
type SupportedFunction = SupportedFunctionParams &
(
| {
// Represents a binary function (e.g., addition, subtraction) that operates on two inputs.
unary?: false;
evalTest: (lhsNumber: bigint, rhsNumber: bigint, lhs: number, rhs: number) => number | boolean | bigint;
}
| {
// Represents a unary function (e.g., negation, bitwise NOT) that operates on a single input.
unary: true;
evalTest: (lhs: bigint, bits: number) => number | boolean | bigint;
}
);
/**
* Safely evaluates a function with given inputs, ensuring the result does not exceed a specified bit limit.
*
* @param fn - A function that takes four arguments: `lhsNumber`, `rhsNumber`, `lhs`, and `rhs`, and returns a `number`, `boolean`, or `bigint`.
* @param lhsNumber - The left-hand side number as a `bigint`.
* @param rhsNumber - The right-hand side number as a `bigint`.
* @param lhs - The left-hand side number as a `number`.
* @param rhs - The right-hand side number as a `number`.
* @param safeMin - A boolean flag indicating whether to use the minimum (`Math.min`) or maximum (`Math.max`) of `lhs` and `rhs` to determine the bit limit. Defaults to `false`.
* @returns An object containing:
* - `inputs`: An array with the adjusted `lhsNumber` and `rhsNumber`.
* - `output`: The result of the function `fn` after adjustments.
*/
const safeEval = (
fn: (lhsNumber: bigint, rhsNumber: bigint, lhs: number, rhs: number) => number | boolean | bigint,
lhsNumber: bigint,
rhsNumber: bigint,
lhs: number,
rhs: number,
safeMin: boolean = false,
) => {
const bitResults = safeMin ? Math.min(lhs, rhs) : Math.max(lhs, rhs);
let result = fn(lhsNumber, rhsNumber, lhs, rhs);
if (typeof result === 'number' || typeof result === 'bigint') {
while ((result as number | bigint) > Math.pow(2, bitResults) - 1) {
lhsNumber = lhsNumber / 2n + 1n;
rhsNumber = rhsNumber / 2n + 1n;
result = fn(lhsNumber, rhsNumber, lhs, rhs);
}
}
return { inputs: [lhsNumber, rhsNumber], output: result };
};
export const SUPPORTED_FUNCTIONS: SupportedFunctions = {
add: {
safeMin: true,
evalTest: (lhsNumber, rhsNumber) => BigInt(lhsNumber) + BigInt(rhsNumber),
},
sub: {
lhsHigher: true,
evalTest: (lhsNumber, rhsNumber) => lhsNumber - rhsNumber,
},
mul: {
safeMin: true,
evalTest: (lhsNumber, rhsNumber) => BigInt(lhsNumber) * BigInt(rhsNumber),
},
div: {
evalTest: (lhsNumber, rhsNumber) => lhsNumber / rhsNumber,
scalarOnly: true,
},
rem: {
evalTest: (lhsNumber, rhsNumber) => lhsNumber % rhsNumber,
scalarOnly: true,
},
le: {
evalTest: (lhsNumber, rhsNumber) => lhsNumber <= rhsNumber,
},
lt: {
evalTest: (lhsNumber, rhsNumber) => lhsNumber < rhsNumber,
},
ge: {
evalTest: (lhsNumber, rhsNumber) => lhsNumber >= rhsNumber,
},
gt: {
evalTest: (lhsNumber, rhsNumber) => lhsNumber > rhsNumber,
},
eq: {
evalTest: (lhsNumber, rhsNumber) => lhsNumber === rhsNumber,
},
ne: {
evalTest: (lhsNumber, rhsNumber) => lhsNumber !== rhsNumber,
},
shl: {
limit: 'bits',
evalTest: (lhsNumber, rhsNumber, lhs, _rhs) => {
// Perform a left shift operation by manipulating the bit positions of the binary representation.
const bits = `${new Array(256).fill('0').join('')}${lhsNumber.toString(2)}`.slice(-lhs).split('');
const r = bits.map((_, index) => {
const newIndex = Number(BigInt(index) + (rhsNumber % BigInt(lhs)));
return newIndex >= bits.length ? '0' : bits[newIndex];
});
return BigInt(`0b${r.join('')}`);
},
},
shr: {
limit: 'bits',
evalTest: (lhsNumber, rhsNumber, lhs, _rhs) => {
const bits = `${new Array(256).fill('0').join('')}${lhsNumber.toString(2)}`.slice(-lhs).split('');
const r = bits.map((_, index) => {
const newIndex = Number(BigInt(index) - (rhsNumber % BigInt(lhs)));
return newIndex < 0 ? '0' : bits[newIndex];
});
return BigInt(`0b${r.join('')}`);
},
},
rotl: {
limit: 'bits',
evalTest: (lhsNumber, rhsNumber, lhs, _rhs) => {
const bits = `${new Array(256).fill('0').join('')}${lhsNumber.toString(2)}`.slice(-lhs).split('');
const r = bits.map((_, index) => {
let newIndex = Number(BigInt(index) + (rhsNumber % BigInt(lhs)));
if (newIndex >= lhs) newIndex = newIndex % lhs;
return bits[newIndex];
});
return BigInt(`0b${r.join('')}`);
},
},
rotr: {
limit: 'bits',
evalTest: (lhsNumber, rhsNumber, lhs, _rhs) => {
const bits = `${new Array(256).fill('0').join('')}${lhsNumber.toString(2)}`.slice(-lhs).split('');
const r = bits.map((_, index) => {
let newIndex = Number(BigInt(index) - (rhsNumber % BigInt(lhs)));
if (newIndex < 0) newIndex = lhs + newIndex;
return bits[newIndex];
});
return BigInt(`0b${r.join('')}`);
},
},
max: {
unary: false,
evalTest: (lhsNumber, rhsNumber) => (lhsNumber > rhsNumber ? lhsNumber : rhsNumber),
},
min: {
evalTest: (lhsNumber, rhsNumber) => (lhsNumber < rhsNumber ? lhsNumber : rhsNumber),
},
or: {
evalTest: (lhsNumber, rhsNumber) => lhsNumber | rhsNumber,
},
and: {
evalTest: (lhsNumber, rhsNumber) => lhsNumber & rhsNumber,
},
xor: {
evalTest: (lhsNumber, rhsNumber) => lhsNumber ^ rhsNumber,
},
not: {
unary: true,
evalTest: (lhsNumber, bits) => {
const val = `${new Array(256).fill('0').join('')}${lhsNumber.toString(2)}`.slice(-bits).split('');
return BigInt(
`0b${val
.map((v) => {
if (v === '1') return '0';
return '1';
})
.join('')}`,
);
},
},
neg: {
unary: true,
evalTest: (lhsNumber, bits) => {
const val = `${new Array(256).fill('0').join('')}${lhsNumber.toString(2)}`.slice(-bits).split('');
return (
BigInt(
`0b${val
.map((v) => {
if (v === '1') return '0';
return '1';
})
.join('')}`,
) + 1n
);
},
},
};
/**
* Generates test cases for supported functions based on the provided FHE types.
*
* @param fheTypes - An array of FHE types, each containing information about type, bit length, and supported operators.
* @returns An object containing generated test cases for each supported function and FHE type combination.
*/
export const generateOverloads = (fheTypes: FheType[]) => {
const generatedTests: any = {};
Object.keys(SUPPORTED_FUNCTIONS).forEach((functionName: string) => {
const test = SUPPORTED_FUNCTIONS[functionName];
fheTypes.forEach((lhsFheType: FheType) => {
if (lhsFheType.type.startsWith('Uint') && lhsFheType.supportedOperators.includes(functionName)) {
if (test.unary) {
let lhsNumber = generateRandomNumber(lhsFheType.bitLength);
const encryptedTestName = [functionName, `e${lhsFheType.type.toLowerCase()}`].join('_');
const encryptedTests: Test[] = [];
encryptedTests.push({
inputs: [lhsNumber],
output: test.evalTest(lhsNumber, lhsFheType.bitLength),
});
generatedTests[encryptedTestName] = encryptedTests;
} else {
fheTypes.forEach((rhsFheType: FheType) => {
if (rhsFheType.type.startsWith('Uint') && rhsFheType.supportedOperators.includes(functionName)) {
const bitResults = Math.min(lhsFheType.bitLength, rhsFheType.bitLength);
let lhsNumber = generateRandomNumber(lhsFheType.bitLength);
let rhsNumber = generateRandomNumber(rhsFheType.bitLength);
if (test.limit === 'bits') {
// @dev We set the floor as 5 to prevent underflows since tests would use smallest - 4n.
rhsNumber = BigInt(5 + Math.floor(Math.random() * (rhsFheType.bitLength - 1)));
}
const smallest = findMinimumValueInBigIntArray(lhsNumber, rhsNumber);
const only8bits = test.limit === 'bits' && rhsFheType.bitLength === 8;
if ((test.limit !== 'bits' || only8bits) && !test.scalarOnly) {
const encryptedTestName = [
functionName,
`e${lhsFheType.type.toLowerCase()}`,
`e${rhsFheType.type.toLowerCase()}`,
].join('_');
const encryptedTests: Test[] = [];
if (!test.lhsHigher) {
encryptedTests.push(
safeEval(
test.evalTest,
lhsNumber,
rhsNumber,
lhsFheType.bitLength,
rhsFheType.bitLength,
test.safeMin,
),
);
encryptedTests.push(
safeEval(
test.evalTest,
smallest - 4n,
smallest,
lhsFheType.bitLength,
rhsFheType.bitLength,
test.safeMin,
),
);
}
encryptedTests.push(
safeEval(test.evalTest, smallest, smallest, lhsFheType.bitLength, rhsFheType.bitLength, test.safeMin),
);
encryptedTests.push(
safeEval(
test.evalTest,
smallest,
smallest - 4n,
lhsFheType.bitLength,
rhsFheType.bitLength,
test.safeMin,
),
);
generatedTests[encryptedTestName] = encryptedTests;
}
const scalarCondition = !test.noScalar && lhsFheType.bitLength === rhsFheType.bitLength;
if (only8bits || (test.limit !== 'bits' && scalarCondition)) {
if (test.limit !== 'bits') {
rhsNumber = generateRandomNumber(bitResults);
}
const encryptedTestName = [
functionName,
`e${lhsFheType.type.toLowerCase()}`,
`uint${rhsFheType.bitLength}`,
].join('_');
const encryptedTests: Test[] = [];
if (!test.lhsHigher) {
encryptedTests.push(
safeEval(
test.evalTest,
lhsNumber,
rhsNumber,
lhsFheType.bitLength,
rhsFheType.bitLength,
test.safeMin,
),
);
encryptedTests.push(
safeEval(
test.evalTest,
smallest - 4n,
smallest,
lhsFheType.bitLength,
rhsFheType.bitLength,
test.safeMin,
),
);
}
encryptedTests.push(
safeEval(test.evalTest, smallest, smallest, lhsFheType.bitLength, rhsFheType.bitLength, test.safeMin),
);
encryptedTests.push(
safeEval(
test.evalTest,
smallest,
smallest - 4n,
lhsFheType.bitLength,
rhsFheType.bitLength,
test.safeMin,
),
);
generatedTests[encryptedTestName] = encryptedTests;
}
if (test.limit !== 'bits' && scalarCondition && !test.scalarOnly) {
lhsNumber = generateRandomNumber(bitResults);
const encryptedTestName = [
functionName,
`uint${lhsFheType.bitLength}`,
`e${rhsFheType.type.toLowerCase()}`,
].join('_');
const encryptedTests: Test[] = [];
if (!test.lhsHigher) {
encryptedTests.push(
safeEval(
test.evalTest,
lhsNumber,
rhsNumber,
lhsFheType.bitLength,
rhsFheType.bitLength,
test.safeMin,
),
);
encryptedTests.push(
safeEval(
test.evalTest,
smallest - 4n,
smallest,
lhsFheType.bitLength,
rhsFheType.bitLength,
test.safeMin,
),
);
}
encryptedTests.push(
safeEval(test.evalTest, smallest, smallest, lhsFheType.bitLength, rhsFheType.bitLength, test.safeMin),
);
encryptedTests.push(
safeEval(
test.evalTest,
smallest,
smallest - 4n,
lhsFheType.bitLength,
rhsFheType.bitLength,
test.safeMin,
),
);
generatedTests[encryptedTestName] = encryptedTests;
}
}
});
}
}
});
});
return generatedTests;
};