From 48197d55ac5fec89306aabd809ae810adb825e67 Mon Sep 17 00:00:00 2001 From: rudy Date: Mon, 20 Dec 2021 17:00:56 +0100 Subject: [PATCH] feat(stress-tests): stress tests first case Resolves #330 --- compiler/Makefile | 19 +- compiler/src/main.cpp | 17 +- .../python/test_compiler_file_output/utils.py | 1 + compiler/tests/stress_tests/README.md | 70 ++++++ compiler/tests/stress_tests/__init__.py | 0 compiler/tests/stress_tests/experiment.py | 32 +++ compiler/tests/stress_tests/read_mlir.py | 31 +++ compiler/tests/stress_tests/test_stress.py | 216 ++++++++++++++++++ compiler/tests/stress_tests/utils.py | 19 ++ compiler/tests/stress_tests/v0_parameters.py | 50 ++++ 10 files changed, 451 insertions(+), 4 deletions(-) create mode 100644 compiler/tests/stress_tests/README.md create mode 100644 compiler/tests/stress_tests/__init__.py create mode 100644 compiler/tests/stress_tests/experiment.py create mode 100644 compiler/tests/stress_tests/read_mlir.py create mode 100644 compiler/tests/stress_tests/test_stress.py create mode 100644 compiler/tests/stress_tests/utils.py create mode 100644 compiler/tests/stress_tests/v0_parameters.py diff --git a/compiler/Makefile b/compiler/Makefile index 0dcc3bb0d..ee51099b8 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -3,6 +3,8 @@ Python3_EXECUTABLE= BINDINGS_PYTHON_ENABLED=ON PARALLEL_EXECUTION_ENABLED=OFF +export PATH := $(BUILD_DIR)/bin:$(PATH) + ifeq ($(shell which ccache),) CCACHE=OFF else @@ -49,7 +51,7 @@ test-check: zamacompiler file-check not $(BUILD_DIR)/bin/llvm-lit -v tests/ test-python: python-bindings zamacompiler - PATH=$(BUILD_DIR)/bin:${PATH} PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/zamalang/python_packages/zamalang_core LD_PRELOAD=$(BUILD_DIR)/lib/libZamalangRuntime.so pytest -vs tests/python + PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/zamalang/python_packages/zamalang_core LD_PRELOAD=$(BUILD_DIR)/lib/libZamalangRuntime.so pytest -vs tests/python test: test-check test-end-to-end-jit test-python @@ -104,6 
+106,21 @@ test-end-to-end-jit-auto-parallelization: build-end-to-end-jit-auto-parallelizat test-end-to-end-jit: test-end-to-end-jit-test test-end-to-end-jit-clear-tensor test-end-to-end-jit-encrypted-tensor test-end-to-end-jit-hlfhelinalg +show-stress-tests-summary: + @echo '------ Stress tests summary ------' + @echo + @echo 'Rates:' + @cd tests/stress_tests/trace && grep success_rate -R + @echo + @echo 'Parameters issues:' + @cd tests/stress_tests/trace && grep BAD -R || echo 'No issues' + +stress-tests: zamacompiler + pytest -vs tests/stress_tests + +# useful for faster cache generation, need pytest-parallel +stress-tests-fast-cache: zamacompiler + pytest --workers auto -vs tests/stress_tests # LLVM/MLIR dependencies diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index 6e6dc6be5..3e6344278 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -138,6 +138,10 @@ llvm::cl::list llvm::cl::desc("Value of arguments to pass to the main func"), llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore); +llvm::cl::opt jitKeySetCachePath( + "jit-keyset-cache-path", + llvm::cl::desc("Path to cache KeySet content (unsecure)")); + llvm::cl::opt, false, OptionalSizeTParser> assumeMaxEintPrecision( "assume-max-eint-precision", @@ -234,7 +238,9 @@ mlir::LogicalResult processInputBuffer( llvm::Optional overrideMaxEintPrecision, llvm::Optional overrideMaxMANP, bool verifyDiagnostics, llvm::Optional> hlfhelinalgTileSizes, - bool autoParallelize, llvm::raw_ostream &os, + bool autoParallelize, + llvm::Optional keySetCache, + llvm::raw_ostream &os, std::shared_ptr outputLib) { std::shared_ptr ccx = mlir::zamalang::CompilationContext::createShared(); @@ -262,7 +268,7 @@ mlir::LogicalResult processInputBuffer( if (action == Action::JIT_INVOKE) { llvm::Expected lambdaOrErr = - ce.buildLambda(std::move(buffer), jitFuncName); + ce.buildLambda(std::move(buffer), jitFuncName, keySetCache); if (!lambdaOrErr) { mlir::zamalang::log_error() @@ -376,6 +382,11 @@ 
mlir::LogicalResult compilerMain(int argc, char **argv) { if (!cmdline::hlfhelinalgTileSizes.empty()) hlfhelinalgTileSizes.emplace(cmdline::hlfhelinalgTileSizes); + llvm::Optional jitKeySetCache; + if (!cmdline::jitKeySetCachePath.empty()) { + jitKeySetCache = mlir::zamalang::KeySetCache(cmdline::jitKeySetCachePath); + } + // In case of compilation to library, the real output is the library. std::string outputPath = (cmdline::action == Action::COMPILE) ? cmdline::STDOUT : cmdline::output; @@ -411,7 +422,7 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { cmdline::jitFuncName, cmdline::jitArgs, cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP, cmdline::verifyDiagnostics, hlfhelinalgTileSizes, - cmdline::autoParallelize, os, outputLib); + cmdline::autoParallelize, jitKeySetCache, os, outputLib); }; auto &os = output->os(); auto res = mlir::failure(); diff --git a/compiler/tests/python/test_compiler_file_output/utils.py b/compiler/tests/python/test_compiler_file_output/utils.py index c00389eb3..007cab502 100644 --- a/compiler/tests/python/test_compiler_file_output/utils.py +++ b/compiler/tests/python/test_compiler_file_output/utils.py @@ -34,3 +34,4 @@ def run(*cmd): if result.returncode != 0: print(result.stderr) assert result.returncode == 0, ' '.join(cmd) + return str(result.stdout, encoding='utf-8') diff --git a/compiler/tests/stress_tests/README.md b/compiler/tests/stress_tests/README.md new file mode 100644 index 000000000..ca6f6d05d --- /dev/null +++ b/compiler/tests/stress_tests/README.md @@ -0,0 +1,70 @@ +# Run and display summary + +## Run + +You can: +- ```make stress-tests```, test replications are parallelized but the tests themselves are not +- ```make stress-tests-fast-cache```, tests and KeySetCache generation are parallelized, useful for a first run + +## Summary + +```make show-stress-tests-summary``` + +# Raw results + +In directory ```stress_tests/trace```: +- ```test_controlled```, contains experiments with controlled code and parameters cases 
that should run fine +- ```test_wild```, contains experiments with less controlled code and parameters that explore the limits of the compiler. + +All experiments are currently a weighted sum with a constant weight followed by an identity function. + +These two directories contain one experiment file per experiment, named ```XXXbits_x_YYY_W``` where XXX is the precision, YYY is the size of the computation and W is the experiment's non-structural parameter (here the weight in the sum). + +Files are in JSON format but can easily be grepped (multi-lines). + +# Experiment file + +```json +{ + # Command line to relaunch an experiment replication by hand + "cmd": "zamacompiler /tmp/stresstests/basic_001_002_1.mlir --action=jit-invoke --jit-funcname=main --jit-args=1 --jit-args=1", + # General information about the experiment + "conditions": { + "bitwidth": 1, # precision in bits + "size": 2, # size of the computation + "args": [ # jit arguments + 1, + 1 + ], + "log_manp_max": 3, # value computed by zamacompiler + "overflow": true, # does the exact computation overflow the precision + "details": [ + "OVERFLOW" + ] # message related to potential issues + }, + # Replications results + "replications": [ + { + "success": true, + "details": [] + }, # A successful replication + { + "success": true, + "details": [ + "OVERFLOW 3" + ] + }, # A successful replication with the overflow value, result being correct when truncated + { + "success": false, + "details": [ + "OVERFLOW 3", + "Expected: 4 vs. 3 (no modulo 0 vs. 1)" + ] + } # A failed replication when the result is wrong both directly and after truncation + ... + ], + "code": "\nfunc @main(...) { ... 
}", + "success_rate": 99.0, + "overflow_rate": 100.0 +} +``` diff --git a/compiler/tests/stress_tests/__init__.py b/compiler/tests/stress_tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/compiler/tests/stress_tests/experiment.py b/compiler/tests/stress_tests/experiment.py new file mode 100644 index 000000000..d3007d4be --- /dev/null +++ b/compiler/tests/stress_tests/experiment.py @@ -0,0 +1,32 @@ +from dataclasses import asdict, dataclass +import json + +@dataclass +class Replication: + success: bool + details: 'list[str]' + +@dataclass +class ExperimentConditions: + bitwidth: int + size: int + args: 'list[int]' + log_manp_max: int + overflow: bool + details: 'list[str]' + +@dataclass +class Experiment: + cmd: str + conditions: ExperimentConditions + replications: 'list[Replication]' + code: str + success_rate: float + overflow_rate: float + +class Encoder(json.JSONEncoder): + def default(self, z): + try: + return super().default(z) + except: + return asdict(z) diff --git a/compiler/tests/stress_tests/read_mlir.py b/compiler/tests/stress_tests/read_mlir.py new file mode 100644 index 000000000..a0a8aed13 --- /dev/null +++ b/compiler/tests/stress_tests/read_mlir.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass +import re + +from stress_tests.utils import ZAMACOMPILER, log2, ceil_log2, run + + +DUMP_HLFHE = '--action=dump-hlfhe' +DUMP_LOWLFHE = '--action=dump-lowlfhe' + +def read_max_mlir_attribute(name, content): + regexp = re.compile(f'{name} = (?P[0-9]+)') + return max( + int(found.group('value')) + for found in regexp.finditer(content) + ) + +def log_manp_max(path): + hlfhe = run(ZAMACOMPILER, path, DUMP_HLFHE) + return ceil_log2(read_max_mlir_attribute('MANP', hlfhe)) + +@dataclass +class FHEParams: + log_poly_size: int + glwe_dim: int + +def v0_param(path): + lowlfhe = run(ZAMACOMPILER, path, DUMP_LOWLFHE) + return FHEParams( + log_poly_size=log2(read_max_mlir_attribute('polynomialSize', lowlfhe)), + 
glwe_dim=read_max_mlir_attribute('glweDimension', lowlfhe), + ) diff --git a/compiler/tests/stress_tests/test_stress.py b/compiler/tests/stress_tests/test_stress.py new file mode 100644 index 000000000..d3aeacbbe --- /dev/null +++ b/compiler/tests/stress_tests/test_stress.py @@ -0,0 +1,216 @@ +import contextlib +import concurrent.futures as futures +from itertools import chain +import json +import os +from tempfile import gettempdir + +import pytest + +from stress_tests.experiment import ( + ExperimentConditions, Experiment, Encoder, Replication +) +from stress_tests import read_mlir +from stress_tests.utils import ZAMACOMPILER, run +from stress_tests.v0_parameters import ( + LOG2_MANP_MAX, P_MAX, + v0_parameter +) + +POSSIBLE_BITWIDTH = range(1, P_MAX+1) +POSSIBLE_SIZE = range(1, 128) + +TEST_PATH = os.path.dirname(__file__) +TRACE = os.path.join(TEST_PATH, 'trace') + +JIT_INVOKE_MAIN = ( + '--action=jit-invoke', + '--jit-funcname=main', + '--jit-keyset-cache-path=/tmp/StresstestsCache', +) + +def jit_args(*params): + return tuple( + f'--jit-args={p}' for p in params + ) + +CONTROLLED_CODE_PARAMS = sorted(chain.from_iterable( + { #(bitwidth, size, input value) + (bitwidth, POSSIBLE_SIZE[-1], 0), + (bitwidth, 1, 1), + (bitwidth, bitwidth, 1), + (bitwidth, 2 ** (bitwidth - 2), 1), + (bitwidth, 2 ** (bitwidth - 1), 1), + (bitwidth, 2 ** bitwidth - 1, 1), + (bitwidth, 2 ** bitwidth, 1), # force carry + (bitwidth, 2 ** (bitwidth+1), 1), # force overflow and carry 0 ? 
+ }# <-- a set to deduplicate similar cases + for bitwidth in POSSIBLE_BITWIDTH +)) + + +CONTROLLED_CODE_PARAMS = [ + case for case in CONTROLLED_CODE_PARAMS + if case[1] >= 1 +] +TEST_CONTROLLED_REPLICATE = 100 + +WILD_CODE_PARAMS = list(sorted(chain.from_iterable( + { #(bitwidth, size, input value) + (bitwidth, 2 ** bitwidth + 8, 1), + (bitwidth, 2 ** bitwidth + 9, 1), + (bitwidth, 2 ** bitwidth + 16, 1), + (bitwidth, 2 ** bitwidth + 17, 1), + (bitwidth, 2 ** (2 * bitwidth), 1), + (bitwidth, 2 ** (2 * bitwidth) + 1, 1), + }# <-- a set to deduplicate similar cases + for bitwidth in POSSIBLE_BITWIDTH +))) +TEST_WILD_RETRY = 3 + +def basic_multisum_identity(bitwidth, size): + def components(name, size, ty=''): + ty_annot = ' : ' + ty if ty else '' + return ', '.join(f'%{name}{i}{ty_annot}' for i in range(size)) + def tensor(size, ty): + return f'tensor<{size}x{ty}>' + v_ty = f"!HLFHE.eint<{bitwidth}>" + tv_ty = tensor(size, v_ty) + w_ty = f"i{bitwidth+1}" + w_modulo = 2 ** bitwidth # to match v bitwidth + tw_ty = tensor(size, w_ty) + lut_size = 2**bitwidth + lut_ty = 'i64' + tlut_ty = tensor(lut_size, lut_ty) + + return ( +f""" +func @main({components('v', size, v_ty)}) -> {v_ty} {{ + %v = tensor.from_elements {components('v', size)} : {tv_ty} + + // Declare {size} %wX components + { ''.join(f''' + %w{i} = arith.constant 1: {w_ty}''' + for i in range(size) + )} + %w = tensor.from_elements {components('w', size)} : {tw_ty} + + // Declare {lut_size} %lutX components + { ''.join(f''' + %lut{i} = arith.constant {i}: i64''' + for i in range(lut_size) + )} + %lut = tensor.from_elements {components('lut', lut_size)} : {tlut_ty} + + %dot_product = "HLFHELinalg.dot_eint_int"(%v, %w) : ({tv_ty}, {tw_ty}) -> {v_ty} + %pbs_result = "HLFHE.apply_lookup_table"(%dot_product, %lut): ({v_ty}, {tlut_ty}) -> {v_ty} + return %pbs_result: {v_ty} +}} +""" + ) + + +executor = futures.ThreadPoolExecutor() + +def basic_setup(bitwidth, size, const, retry=10): + code = 
basic_multisum_identity(bitwidth, size) + args = (const,) * size + expected = eval_basic_multisum_identity(bitwidth, args) + with tmp_file(f'basic_{bitwidth:03}_{size:03}_{const}.mlir', code) as path: + modulo = 2 ** bitwidth + # Read various value from compiler + log_manp_max = read_mlir.log_manp_max(path) + params = read_mlir.v0_param(path) + # From CPP source + expected_params = v0_parameter(log_manp_max, bitwidth) + expected_log_poly_size = expected_params.logPolynomialSize + expected_glwe_dim = expected_params.glweDimension + + conditions_details = [] + def msg(m, append_here=None, space=' '): + print(m, end=space, flush=True) # test human output + if append_here is not None: + append_here.append(m) + + if (LOG2_MANP_MAX < log_manp_max): + msg('HIGH-MANP', conditions_details) + if 2 ** bitwidth <= expected: + msg(f'OVERFLOW', conditions_details) + if params.log_poly_size != expected_log_poly_size: + msg(f'BAD_LOGPOLYSIZE({params.log_poly_size} vs {expected_log_poly_size})', conditions_details) + if params.glwe_dim != expected_glwe_dim: + msg(f'BAD_GLWEDIM ({params.glwe_dim} vs {expected_glwe_dim})', conditions_details) + + cmd = (ZAMACOMPILER, path) + JIT_INVOKE_MAIN + jit_args(*args) + compilers_calls = [executor.submit(run, *cmd) for _ in range(retry)] + + success = 0 + overflow = 0 + replications = [] + for replication in futures.as_completed(compilers_calls): + result = int(replication.result().splitlines()[-1]) + correct_in_modulo = expected % modulo == result % modulo + details = [] + replications.append(Replication(correct_in_modulo, details)) + if not (0 <= result < modulo): + msg(f'OVERFLOW {result}', details) + overflow += 1 + if correct_in_modulo: + msg('O', space='') + success += 1 + else: + msg('X', space='') + diff = f'Expected :{expected % modulo} vs. {result % modulo} (no modulo {expected} vs. 
{result}' + details.append(diff) + + print(' ', end='') + add_to(TRACE, Experiment( + cmd = ' '.join(cmd), + conditions=ExperimentConditions( + bitwidth=bitwidth, size=size, args=args, + log_manp_max=log_manp_max, + overflow=2 ** bitwidth <= expected, + details=conditions_details,), + replications=replications, + code=code, + success_rate=100.0 * success/retry, + overflow_rate=100.0 * overflow/retry, + )) + + assert success == len(replications) + + +def eval_basic_multisum_identity(bitwidth, args): + return sum( + arg + for arg in args + ) + +@contextlib.contextmanager +def tmp_file(name, content, delete=False): + path = os.path.join(gettempdir(), 'stresstests', name) + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, 'w') as f: + f.write(content) + yield f.name + if delete: + os.remove() + +def add_to(DIR, expe: Experiment): + full_test_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0] + test_name = full_test_name.rsplit('[', 1)[0] + DIR = os.path.join(DIR, test_name) + os.makedirs(DIR, exist_ok=True) + conditions = expe.conditions + name = f'{conditions.bitwidth:03}bits_x_{conditions.size:03}_{conditions.args[0]}' + with open(os.path.join(DIR, name), 'w') as f: + json.dump(expe, f, indent=2, cls=Encoder) + + +@pytest.mark.parametrize("bitwidth, size, const", CONTROLLED_CODE_PARAMS) +def test_controlled(bitwidth, size, const): + return basic_setup(bitwidth, size, const, TEST_CONTROLLED_REPLICATE) + +@pytest.mark.parametrize("bitwidth, size, const", WILD_CODE_PARAMS) +def test_wild(bitwidth, size, const): + return basic_setup(bitwidth, size, const, TEST_WILD_RETRY) diff --git a/compiler/tests/stress_tests/utils.py b/compiler/tests/stress_tests/utils.py new file mode 100644 index 000000000..94163edd2 --- /dev/null +++ b/compiler/tests/stress_tests/utils.py @@ -0,0 +1,19 @@ +import subprocess + +ZAMACOMPILER = 'zamacompiler' + +def ceil_log2(v, exact=False): + import math + log_v = math.ceil(math.log(v) / math.log(2)) + 
assert not exact or v == 2 ** log_v + return log_v + +def log2(v: int): + return ceil_log2(v, exact=True) + +def run(*cmd): + result = subprocess.run(cmd, capture_output=True) + if result.returncode != 0: + print(result.stderr) + assert result.returncode == 0, ' '.join(cmd) + return str(result.stdout, encoding='utf-8') diff --git a/compiler/tests/stress_tests/v0_parameters.py b/compiler/tests/stress_tests/v0_parameters.py new file mode 100644 index 000000000..a38277086 --- /dev/null +++ b/compiler/tests/stress_tests/v0_parameters.py @@ -0,0 +1,50 @@ +""" Read parameters matrix in V0Parameters.cpp """ +from dataclasses import dataclass +import re + +@dataclass +class V0Parameter: + glweDimension: int + logPolynomialSize:int + nSmall: int + brLevel: int + brLogBase: int + ksLevel: int + ksLogBase:int + +# [log_manp][bitwidth] +v0_parameters : 'list[list[V0Parameter]]' = [] + +def v0_parameter(log_manp_max, bitwidth): + try: + return v0_parameters[log_manp_max - 1][bitwidth - 1] + except IndexError: + return V0Parameter( *( ['out_of_V0Parameters'] * 7) ) + +# relative to Makefile +V0Parameters_PATH = 'lib/Support/V0Parameters.cpp' + +def read_CPP_decl(name, cpp_filepath): + DECLARE = re.compile(f'{name}[^=\n]+=') + END = re.compile(';') + with open(cpp_filepath) as f: + content = f.read() + decl = DECLARE.search(content) + if not decl: + raise NameError(f'Cannot find {name} declaration in file {cpp_filepath}') + end = END.search(content, decl.end()) + assert end + value = content[decl.end()+1:end.end()-1].strip() + # print(f'{name} = {value}', decl.group()) + return value + +LOG2_MANP_MAX = int(read_CPP_decl('NORM2_MAX', V0Parameters_PATH)) +P_MAX = int(read_CPP_decl('P_MAX', V0Parameters_PATH)) + +try: + parameters_cpp = read_CPP_decl('parameters', V0Parameters_PATH) +except (FileNotFoundError, NameError) as exc: + print(exc) + assert False + +v0_parameters = eval(parameters_cpp.replace('{', '[').replace('}', ']'))