feat(stress-tests): stress tests first case

Resolves #330
This commit is contained in:
rudy
2021-12-20 17:00:56 +01:00
committed by Ayoub Benaissa
parent cdffb0ec8e
commit 48197d55ac
10 changed files with 451 additions and 4 deletions

View File

@@ -3,6 +3,8 @@ Python3_EXECUTABLE=
BINDINGS_PYTHON_ENABLED=ON
PARALLEL_EXECUTION_ENABLED=OFF
export PATH := $(BUILD_DIR)/bin:$(PATH)
ifeq ($(shell which ccache),)
CCACHE=OFF
else
@@ -49,7 +51,7 @@ test-check: zamacompiler file-check not
$(BUILD_DIR)/bin/llvm-lit -v tests/
test-python: python-bindings zamacompiler
PATH=$(BUILD_DIR)/bin:${PATH} PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/zamalang/python_packages/zamalang_core LD_PRELOAD=$(BUILD_DIR)/lib/libZamalangRuntime.so pytest -vs tests/python
PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/zamalang/python_packages/zamalang_core LD_PRELOAD=$(BUILD_DIR)/lib/libZamalangRuntime.so pytest -vs tests/python
test: test-check test-end-to-end-jit test-python
@@ -104,6 +106,21 @@ test-end-to-end-jit-auto-parallelization: build-end-to-end-jit-auto-parallelizat
test-end-to-end-jit: test-end-to-end-jit-test test-end-to-end-jit-clear-tensor test-end-to-end-jit-encrypted-tensor test-end-to-end-jit-hlfhelinalg
show-stress-tests-summary:
@echo '------ Stress tests summary ------'
@echo
@echo 'Rates:'
@cd tests/stress_tests/trace && grep success_rate -R
@echo
@echo 'Parameters issues:'
@cd tests/stress_tests/trace && grep BAD -R || echo 'No issues'
stress-tests: zamacompiler
pytest -vs tests/stress_tests
# useful for faster cache generation, need pytest-parallel
stress-tests-fast-cache: zamacompiler
pytest --workers auto -vs tests/stress_tests
# LLVM/MLIR dependencies

View File

@@ -138,6 +138,10 @@ llvm::cl::list<uint64_t>
llvm::cl::desc("Value of arguments to pass to the main func"),
llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore);
llvm::cl::opt<std::string> jitKeySetCachePath(
"jit-keyset-cache-path",
llvm::cl::desc("Path to cache KeySet content (unsecure)"));
llvm::cl::opt<llvm::Optional<size_t>, false, OptionalSizeTParser>
assumeMaxEintPrecision(
"assume-max-eint-precision",
@@ -234,7 +238,9 @@ mlir::LogicalResult processInputBuffer(
llvm::Optional<size_t> overrideMaxEintPrecision,
llvm::Optional<size_t> overrideMaxMANP, bool verifyDiagnostics,
llvm::Optional<llvm::ArrayRef<int64_t>> hlfhelinalgTileSizes,
bool autoParallelize, llvm::raw_ostream &os,
bool autoParallelize,
llvm::Optional<mlir::zamalang::KeySetCache> keySetCache,
llvm::raw_ostream &os,
std::shared_ptr<mlir::zamalang::CompilerEngine::Library> outputLib) {
std::shared_ptr<mlir::zamalang::CompilationContext> ccx =
mlir::zamalang::CompilationContext::createShared();
@@ -262,7 +268,7 @@ mlir::LogicalResult processInputBuffer(
if (action == Action::JIT_INVOKE) {
llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
ce.buildLambda(std::move(buffer), jitFuncName);
ce.buildLambda(std::move(buffer), jitFuncName, keySetCache);
if (!lambdaOrErr) {
mlir::zamalang::log_error()
@@ -376,6 +382,11 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
if (!cmdline::hlfhelinalgTileSizes.empty())
hlfhelinalgTileSizes.emplace(cmdline::hlfhelinalgTileSizes);
llvm::Optional<mlir::zamalang::KeySetCache> jitKeySetCache;
if (!cmdline::jitKeySetCachePath.empty()) {
jitKeySetCache = mlir::zamalang::KeySetCache(cmdline::jitKeySetCachePath);
}
// In case of compilation to library, the real output is the library.
std::string outputPath =
(cmdline::action == Action::COMPILE) ? cmdline::STDOUT : cmdline::output;
@@ -411,7 +422,7 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
cmdline::jitFuncName, cmdline::jitArgs,
cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP,
cmdline::verifyDiagnostics, hlfhelinalgTileSizes,
cmdline::autoParallelize, os, outputLib);
cmdline::autoParallelize, jitKeySetCache, os, outputLib);
};
auto &os = output->os();
auto res = mlir::failure();

View File

@@ -34,3 +34,4 @@ def run(*cmd):
if result.returncode != 0:
print(result.stderr)
assert result.returncode == 0, ' '.join(cmd)
return str(result.stdout, encoding='utf-8')

View File

@@ -0,0 +1,70 @@
# Run and display summary
## Run
You can:
- ```make stress-tests```, tests replications are parallelized but tests are not parallelized
- ```make stress-tests-fast-cache```, tests and KeySetCache generation are parallelized (requires `pytest-parallel`), useful for a first run
## Summary
```make show-stress-tests-summary```
# Raw results
In directory ```tests/stress_tests/trace```:
- ```test_controlled```, contains experiments with controlled code and parameters cases that should run fine
- ```test_wild```, contains experiments with less controlled code and parameters that explores the limits of the compiler.
All experiments are currently a weighted sum with a constant weight followed by an identity function.
These two directories contain one experiment file per experiment, named ```XXXbits_x_YYY_W``` where XXX is the precision, YYY is the size of the computation and W is the experiment's non-structural parameter (here the weight in the sum).
Files are in json format but can easily be grepped (multi-lines).
# Experiment file
```json
{
# Command line to relaunch an experiment replication by hand
"cmd": "zamacompiler /tmp/stresstests/basic_001_002_1.mlir --action=jit-invoke --jit-funcname=main --jit-args=1 --jit-args=1",
# General information about the experiment
"conditions": {
"bitwidth": 1, # precision in bits
"size": 2, # size of the computation
"args": [ # jit arguments
1,
1
],
"log_manp_max": 3, # value comuted by zamacompiler
"overflow": true, # does the exact computation overflow the precision
"details": [
"OVERFLOW"
] # message related to potential issues 
},
# Replications results
"replications": [
{
"success": true,
"details": []
}, # A successful replication
{
"success": true,
"details": [
"OVERFLOW 3"
]
}, # A successful replication with the overflow value, result being correct when truncated
{
"success": false,
"details": [
"OVERFLOW 3",
"Expected: 4 vs. 3 (no modulo 0 vs. 1)"
]
} # A failed replication when the result is wrong both directly and after truncation
...
],
"code": "\nfunc @main(...) { ... }",
"success_rate": 99.0,
"overflow_rate": 100.0
}
```

View File

View File

@@ -0,0 +1,32 @@
from dataclasses import asdict, dataclass
import json
@dataclass
class Replication:
    """Outcome of one replication (one run) of an experiment."""

    # Whether this replication is counted as a success.
    success: bool
    # Messages recorded for this replication (e.g. overflow notes, diffs).
    details: 'list[str]'
@dataclass
class ExperimentConditions:
    """General, replication-independent information about an experiment."""

    # Precision, in bits.
    bitwidth: int
    # Size of the computation.
    size: int
    # Arguments passed to the JIT-invoked function.
    args: 'list[int]'
    # Value derived from the compiler output for this program.
    log_manp_max: int
    # Whether the exact computation overflows the precision.
    overflow: bool
    # Messages about potential issues with these conditions.
    details: 'list[str]'
@dataclass
class Experiment:
    """Full record of one experiment, as written to a trace file."""

    # Command line used to launch a replication.
    cmd: str
    # Replication-independent information about the experiment.
    conditions: ExperimentConditions
    # One entry per replication that was run.
    replications: 'list[Replication]'
    # Source code of the program under test.
    code: str
    # Percentage of successful replications.
    success_rate: float
    # Percentage of replications whose result overflowed.
    overflow_rate: float
class Encoder(json.JSONEncoder):
    """JSON encoder that additionally knows how to serialize dataclass instances."""

    def default(self, z):
        """Return a JSON-serializable stand-in for ``z``.

        The base implementation is tried first; when it rejects the object
        with ``TypeError`` the object is converted with ``dataclasses.asdict``.

        Raises:
            TypeError: if ``z`` is neither natively serializable nor a
                dataclass instance (raised by ``asdict``).
        """
        try:
            return super().default(z)
        except TypeError:
            # Only the "not serializable" signal is handled here, so that
            # unrelated errors (KeyboardInterrupt, SystemExit, ...) propagate.
            return asdict(z)

View File

@@ -0,0 +1,31 @@
from dataclasses import dataclass
import re
from stress_tests.utils import ZAMACOMPILER, log2, ceil_log2, run
# Compiler command-line options selecting which representation is dumped.
DUMP_HLFHE = '--action=dump-hlfhe'
DUMP_LOWLFHE = '--action=dump-lowlfhe'
def read_max_mlir_attribute(name, content):
    """Return the largest integer value given to attribute ``name`` in ``content``.

    Every occurrence of ``<name> = <digits>`` is considered. ``name`` is
    inserted verbatim into the regular expression.

    Raises:
        ValueError: if ``content`` holds no such occurrence.
    """
    regexp = re.compile(f'{name} = (?P<value>[0-9]+)')
    values = [int(found.group('value')) for found in regexp.finditer(content)]
    if not values:
        # max() on an empty sequence raises the same exception type, but
        # with a message that does not say which attribute was missing.
        raise ValueError(f'No "{name} = <integer>" attribute found')
    return max(values)
def log_manp_max(path):
    """Dump the HLFHE form of the file at ``path`` and return the ceiling of
    log2 of the largest ``MANP`` attribute found in that dump."""
    dumped = run(ZAMACOMPILER, path, DUMP_HLFHE)
    largest_manp = read_max_mlir_attribute('MANP', dumped)
    return ceil_log2(largest_manp)
@dataclass
class FHEParams:
    """Parameters read back from the compiler's LowLFHE dump."""

    # log2 of the largest `polynomialSize` attribute.
    log_poly_size: int
    # Largest `glweDimension` attribute.
    glwe_dim: int
def v0_param(path):
    """Dump the LowLFHE form of the file at ``path`` and collect the
    parameters found in it into an ``FHEParams``."""
    dumped = run(ZAMACOMPILER, path, DUMP_LOWLFHE)
    # log2() asserts that the polynomial size is an exact power of two.
    log_poly = log2(read_max_mlir_attribute('polynomialSize', dumped))
    glwe = read_max_mlir_attribute('glweDimension', dumped)
    return FHEParams(log_poly_size=log_poly, glwe_dim=glwe)

View File

@@ -0,0 +1,216 @@
import contextlib
import concurrent.futures as futures
from itertools import chain
import json
import os
from tempfile import gettempdir
import pytest
from stress_tests.experiment import (
ExperimentConditions, Experiment, Encoder, Replication
)
from stress_tests import read_mlir
from stress_tests.utils import ZAMACOMPILER, run
from stress_tests.v0_parameters import (
LOG2_MANP_MAX, P_MAX,
v0_parameter
)
# Precisions (in bits) covered by the tests: 1 up to and including P_MAX.
POSSIBLE_BITWIDTH = range(1, P_MAX+1)
# Range of computation sizes (1..127).
POSSIBLE_SIZE = range(1, 128)
# Directory holding this test module, and the sub-directory receiving traces.
TEST_PATH = os.path.dirname(__file__)
TRACE = os.path.join(TEST_PATH, 'trace')
# Compiler options shared by every JIT invocation of `main`; the keyset
# cache path is hard-coded to a fixed location under /tmp.
JIT_INVOKE_MAIN = (
'--action=jit-invoke',
'--jit-funcname=main',
'--jit-keyset-cache-path=/tmp/StresstestsCache',
)
def jit_args(*params):
    """Return a tuple holding one ``--jit-args=<value>`` option per value."""
    options = []
    for value in params:
        options.append(f'--jit-args={value}')
    return tuple(options)
# Sorted (bitwidth, size, input value) cases for the controlled tests.
# Cases are built per bitwidth in a set, which removes duplicates, and
# any case whose size is below 1 is discarded.
CONTROLLED_CODE_PARAMS = sorted(
    case
    for bitwidth in POSSIBLE_BITWIDTH
    for case in {
        (bitwidth, POSSIBLE_SIZE[-1], 0),
        (bitwidth, 1, 1),
        (bitwidth, bitwidth, 1),
        (bitwidth, 2 ** (bitwidth - 2), 1),
        (bitwidth, 2 ** (bitwidth - 1), 1),
        (bitwidth, 2 ** bitwidth - 1, 1),
        (bitwidth, 2 ** bitwidth, 1),  # force carry
        (bitwidth, 2 ** (bitwidth+1), 1),  # force overflow and carry 0 ?
    }
    if case[1] >= 1
)

# Number of replications run for each controlled case.
TEST_CONTROLLED_REPLICATE = 100
# Sorted (bitwidth, size, input value) cases for the "wild" tests.
# Cases are built per bitwidth in a set, which removes duplicates.
WILD_CODE_PARAMS = sorted(
    case
    for bitwidth in POSSIBLE_BITWIDTH
    for case in {
        (bitwidth, 2 ** bitwidth + 8, 1),
        (bitwidth, 2 ** bitwidth + 9, 1),
        (bitwidth, 2 ** bitwidth + 16, 1),
        (bitwidth, 2 ** bitwidth + 17, 1),
        (bitwidth, 2 ** (2 * bitwidth), 1),
        (bitwidth, 2 ** (2 * bitwidth) + 1, 1),
    }
)

# Number of replications run for each wild case.
TEST_WILD_RETRY = 3
def basic_multisum_identity(bitwidth, size):
    """Generate the MLIR source of the basic stress-test program.

    The generated ``@main`` takes ``size`` encrypted integers of ``bitwidth``
    bits, computes their dot product with ``size`` clear weights all equal
    to 1, then applies a lookup table of ``2 ** bitwidth`` entries where
    entry ``i`` holds ``i``.

    Args:
        bitwidth: precision of the encrypted integers, in bits.
        size: number of inputs (and of weights).

    Returns:
        The program text as a string.
    """
    def components(name, count, ty=''):
        # '%name0, %name1, ...', each optionally followed by ' : <ty>'.
        ty_annot = ' : ' + ty if ty else ''
        return ', '.join(f'%{name}{i}{ty_annot}' for i in range(count))

    def tensor(count, ty):
        return f'tensor<{count}x{ty}>'

    v_ty = f"!HLFHE.eint<{bitwidth}>"
    tv_ty = tensor(size, v_ty)
    w_ty = f"i{bitwidth+1}"
    tw_ty = tensor(size, w_ty)
    lut_size = 2**bitwidth
    lut_ty = 'i64'
    tlut_ty = tensor(lut_size, lut_ty)

    # One constant declaration per line, each preceded by a newline.
    w_decls = ''.join(
        f'\n%w{i} = arith.constant 1: {w_ty}' for i in range(size)
    )
    lut_decls = ''.join(
        f'\n%lut{i} = arith.constant {i}: i64' for i in range(lut_size)
    )

    return (
f"""
func @main({components('v', size, v_ty)}) -> {v_ty} {{
%v = tensor.from_elements {components('v', size)} : {tv_ty}
// Declare {size} %wX components
{w_decls}
%w = tensor.from_elements {components('w', size)} : {tw_ty}
// Declare {lut_size} %lutX components
{lut_decls}
%lut = tensor.from_elements {components('lut', lut_size)} : {tlut_ty}
%dot_product = "HLFHELinalg.dot_eint_int"(%v, %w) : ({tv_ty}, {tw_ty}) -> {v_ty}
%pbs_result = "HLFHE.apply_lookup_table"(%dot_product, %lut): ({v_ty}, {tlut_ty}) -> {v_ty}
return %pbs_result: {v_ty}
}}
"""
    )
# Module-wide thread pool; basic_setup submits its compiler runs to it.
executor = futures.ThreadPoolExecutor()
def basic_setup(bitwidth, size, const, retry=10):
    """Run one experiment: generate the program, check compiler parameters,
    execute ``retry`` replications and record everything in a trace file.

    Args:
        bitwidth: precision in bits of the encrypted inputs.
        size: number of inputs of the generated program.
        const: value given to every input.
        retry: number of replications to run.

    Raises:
        AssertionError: if at least one replication returned a result that
            differs from the expected one modulo ``2 ** bitwidth``.
    """
    code = basic_multisum_identity(bitwidth, size)
    args = (const,) * size
    expected = eval_basic_multisum_identity(bitwidth, args)

    with tmp_file(f'basic_{bitwidth:03}_{size:03}_{const}.mlir', code) as path:
        modulo = 2 ** bitwidth

        # Read various value from compiler
        log_manp_max = read_mlir.log_manp_max(path)
        params = read_mlir.v0_param(path)

        # From CPP source
        expected_params = v0_parameter(log_manp_max, bitwidth)
        expected_log_poly_size = expected_params.logPolynomialSize
        expected_glwe_dim = expected_params.glweDimension

        conditions_details = []

        def msg(m, append_here=None, space=' '):
            # Echo to stdout for the human reading the test run, and
            # optionally keep the message in the given list.
            print(m, end=space, flush=True)  # test human output
            if append_here is not None:
                append_here.append(m)

        if (LOG2_MANP_MAX < log_manp_max):
            msg('HIGH-MANP', conditions_details)
        if 2 ** bitwidth <= expected:
            msg('OVERFLOW', conditions_details)
        if params.log_poly_size != expected_log_poly_size:
            msg(f'BAD_LOGPOLYSIZE({params.log_poly_size} vs {expected_log_poly_size})', conditions_details)
        if params.glwe_dim != expected_glwe_dim:
            msg(f'BAD_GLWEDIM ({params.glwe_dim} vs {expected_glwe_dim})', conditions_details)

        cmd = (ZAMACOMPILER, path) + JIT_INVOKE_MAIN + jit_args(*args)
        compilers_calls = [executor.submit(run, *cmd) for _ in range(retry)]

        success = 0
        overflow = 0
        replications = []
        for replication in futures.as_completed(compilers_calls):
            # The result is read from the last line of the compiler output.
            result = int(replication.result().splitlines()[-1])
            correct_in_modulo = expected % modulo == result % modulo
            details = []
            # `details` is filled in below; the Replication shares the list.
            replications.append(Replication(correct_in_modulo, details))
            if not (0 <= result < modulo):
                msg(f'OVERFLOW {result}', details)
                overflow += 1
            if correct_in_modulo:
                msg('O', space='')
                success += 1
            else:
                msg('X', space='')
                diff = f'Expected :{expected % modulo} vs. {result % modulo} (no modulo {expected} vs. {result})'
                details.append(diff)
        print(' ', end='')

        add_to(TRACE, Experiment(
            cmd=' '.join(cmd),
            conditions=ExperimentConditions(
                bitwidth=bitwidth, size=size, args=args,
                log_manp_max=log_manp_max,
                overflow=2 ** bitwidth <= expected,
                details=conditions_details,),
            replications=replications,
            code=code,
            success_rate=100.0 * success/retry,
            overflow_rate=100.0 * overflow/retry,
        ))

        assert success == len(replications)
def eval_basic_multisum_identity(bitwidth, args):
    """Return the plain sum of ``args``; ``bitwidth`` is not used."""
    total = 0
    for value in args:
        total = total + value
    return total
@contextlib.contextmanager
def tmp_file(name, content, delete=False):
    """Write ``content`` to ``<tempdir>/stresstests/<name>`` and yield its path.

    Args:
        name: file name, relative to the ``stresstests`` temporary directory.
        content: text written to the file.
        delete: when true, the file is removed on exit of the ``with`` block,
            including when the block raises.

    Yields:
        The path of the written file.
    """
    path = os.path.join(gettempdir(), 'stresstests', name)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    # Close the file before handing out its path so that the content is
    # fully flushed when another process reads it.
    with open(path, 'w') as f:
        f.write(content)
    try:
        yield path
    finally:
        if delete:
            os.remove(path)
def add_to(DIR, expe: Experiment):
    """Write ``expe`` as JSON under ``DIR/<current test name>/``.

    The test name is taken from the ``PYTEST_CURRENT_TEST`` environment
    variable, stripped of its file part, its parametrization suffix and its
    trailing phase. When the variable is not set (the function is called
    outside of a pytest run), ``unknown_test`` is used instead.

    The file is named ``<bitwidth>bits_x_<size>_<first argument>``.
    """
    current_test = os.environ.get('PYTEST_CURRENT_TEST', 'unknown_test')
    full_test_name = current_test.split(':')[-1].split(' ')[0]
    test_name = full_test_name.rsplit('[', 1)[0]
    DIR = os.path.join(DIR, test_name)
    os.makedirs(DIR, exist_ok=True)
    conditions = expe.conditions
    name = f'{conditions.bitwidth:03}bits_x_{conditions.size:03}_{conditions.args[0]}'
    with open(os.path.join(DIR, name), 'w') as f:
        json.dump(expe, f, indent=2, cls=Encoder)
@pytest.mark.parametrize("bitwidth, size, const", CONTROLLED_CODE_PARAMS)
def test_controlled(bitwidth, size, const):
    """Run every controlled case with TEST_CONTROLLED_REPLICATE replications."""
    outcome = basic_setup(bitwidth, size, const, retry=TEST_CONTROLLED_REPLICATE)
    return outcome
@pytest.mark.parametrize("bitwidth, size, const", WILD_CODE_PARAMS)
def test_wild(bitwidth, size, const):
    """Run every wild case with TEST_WILD_RETRY replications."""
    outcome = basic_setup(bitwidth, size, const, retry=TEST_WILD_RETRY)
    return outcome

View File

@@ -0,0 +1,19 @@
import subprocess
# Name of the compiler executable; given without a directory, so it has to be
# resolvable by the process launcher (NOTE(review): presumably through PATH).
ZAMACOMPILER = 'zamacompiler'
def ceil_log2(v, exact=False):
    """Return the smallest integer ``n`` such that ``2 ** n >= v``.

    Args:
        v: a positive number.
        exact: when true, additionally assert that ``v`` is exactly
            ``2 ** n``.

    Raises:
        ValueError: if ``v`` is not positive.
        AssertionError: if ``exact`` is set and ``v`` is not a power of two.
    """
    import math
    if isinstance(v, int) and v > 0:
        # Exact integer arithmetic: a ratio of floating-point logarithms may
        # land slightly off an integer and be rounded to the wrong side.
        log_v = (v - 1).bit_length()
    else:
        log_v = math.ceil(math.log2(v))
    assert not exact or v == 2 ** log_v
    return log_v
def log2(v: int):
    """Return log2 of ``v``, asserting that ``v`` is an exact power of two."""
    exponent = ceil_log2(v, exact=True)
    return exponent
def run(*cmd):
    """Execute ``cmd`` and return its standard output decoded as UTF-8.

    On a non-zero exit status the captured standard error is printed and an
    ``AssertionError`` carrying the space-joined command is raised.
    """
    completed = subprocess.run(cmd, capture_output=True)
    failed = completed.returncode != 0
    if failed:
        print(completed.stderr)
    assert not failed, ' '.join(cmd)
    return completed.stdout.decode('utf-8')

View File

@@ -0,0 +1,50 @@
""" Read parameters matrix in V0Parameters.cpp """
from dataclasses import dataclass
import re
@dataclass
class V0Parameter:
    """One entry of the parameters matrix read from V0Parameters.cpp.

    NOTE(review): the field order presumably mirrors the argument order of
    the C++ `V0Parameter(...)` entries evaluated at the bottom of this
    module; confirm against the C++ source.
    """

    glweDimension: int
    logPolynomialSize: int
    nSmall: int
    brLevel: int
    brLogBase: int
    ksLevel: int
    ksLogBase: int
# Parameters matrix, indexed as [log_manp][bitwidth].
# Declared empty here and reassigned at the end of this module with the
# content read from the C++ source.
v0_parameters : 'list[list[V0Parameter]]' = []
def v0_parameter(log_manp_max, bitwidth):
    """Return the parameters stored for ``log_manp_max`` and ``bitwidth``.

    Both arguments are 1-based indexes into the parameters matrix. When
    either falls outside of the matrix, a ``V0Parameter`` whose seven fields
    all hold the string ``'out_of_V0Parameters'`` is returned instead.
    """
    row = log_manp_max - 1
    col = bitwidth - 1
    out_of_table = V0Parameter(*(['out_of_V0Parameters'] * 7))
    if row < 0 or col < 0:
        # A negative index would silently wrap around to the end of the
        # list instead of raising IndexError.
        return out_of_table
    try:
        return v0_parameters[row][col]
    except IndexError:
        return out_of_table
# Path of the C++ source holding the parameters, relative to the directory
# of the Makefile (i.e. the directory the tests are expected to be run from).
V0Parameters_PATH = 'lib/Support/V0Parameters.cpp'
def read_CPP_decl(name, cpp_filepath):
DECLARE = re.compile(f'{name}[^=\n]+=')
END = re.compile(';')
with open(cpp_filepath) as f:
content = f.read()
decl = DECLARE.search(content)
if not decl:
raise NameError(f'Cannot find {name} declaration in file {cpp_filepath}')
end = END.search(content, decl.end())
assert end
value = content[decl.end()+1:end.end()-1].strip()
# print(f'{name} = {value}', decl.group())
return value
# Limits read from the C++ source at import time; both must parse as integers.
# LOG2_MANP_MAX takes the value of the C++ declaration named NORM2_MAX.
LOG2_MANP_MAX = int(read_CPP_decl('NORM2_MAX', V0Parameters_PATH))
P_MAX = int(read_CPP_decl('P_MAX', V0Parameters_PATH))
# Read the initializer of the C++ `parameters` declaration; on failure the
# error is printed and the import is aborted with an AssertionError.
try:
parameters_cpp = read_CPP_decl('parameters', V0Parameters_PATH)
except (FileNotFoundError, NameError) as exc:
print(exc)
assert False
# C++ braces are turned into Python list brackets and the result is evaluated
# as a Python expression.
# NOTE(review): eval() executes whatever the C++ file contains; acceptable
# only because the file is part of the repository. Do not point this at
# untrusted input.
v0_parameters = eval(parameters_cpp.replace('{', '[').replace('}', ']'))