@@ -3,6 +3,8 @@ Python3_EXECUTABLE=
BINDINGS_PYTHON_ENABLED=ON
PARALLEL_EXECUTION_ENABLED=OFF

export PATH := $(BUILD_DIR)/bin:$(PATH)

ifeq ($(shell which ccache),)
CCACHE=OFF
else
@@ -49,7 +51,7 @@ test-check: zamacompiler file-check not
	$(BUILD_DIR)/bin/llvm-lit -v tests/

test-python: python-bindings zamacompiler
	PATH=$(BUILD_DIR)/bin:${PATH} PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/zamalang/python_packages/zamalang_core LD_PRELOAD=$(BUILD_DIR)/lib/libZamalangRuntime.so pytest -vs tests/python
	PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/zamalang/python_packages/zamalang_core LD_PRELOAD=$(BUILD_DIR)/lib/libZamalangRuntime.so pytest -vs tests/python

test: test-check test-end-to-end-jit test-python

@@ -104,6 +106,21 @@ test-end-to-end-jit-auto-parallelization: build-end-to-end-jit-auto-parallelizat

test-end-to-end-jit: test-end-to-end-jit-test test-end-to-end-jit-clear-tensor test-end-to-end-jit-encrypted-tensor test-end-to-end-jit-hlfhelinalg

show-stress-tests-summary:
	@echo '------ Stress tests summary ------'
	@echo
	@echo 'Rates:'
	@cd tests/stress_tests/trace && grep success_rate -R
	@echo
	@echo 'Parameters issues:'
	@cd tests/stress_tests/trace && grep BAD -R || echo 'No issues'

stress-tests: zamacompiler
	pytest -vs tests/stress_tests

# useful for faster cache generation, needs pytest-parallel
stress-tests-fast-cache: zamacompiler
	pytest --workers auto -vs tests/stress_tests

# LLVM/MLIR dependencies

@@ -138,6 +138,10 @@ llvm::cl::list<uint64_t>
    llvm::cl::desc("Value of arguments to pass to the main func"),
    llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore);

llvm::cl::opt<std::string> jitKeySetCachePath(
    "jit-keyset-cache-path",
    llvm::cl::desc("Path to cache KeySet content (unsecure)"));

llvm::cl::opt<llvm::Optional<size_t>, false, OptionalSizeTParser>
    assumeMaxEintPrecision(
        "assume-max-eint-precision",
@@ -234,7 +238,9 @@ mlir::LogicalResult processInputBuffer(
    llvm::Optional<size_t> overrideMaxEintPrecision,
    llvm::Optional<size_t> overrideMaxMANP, bool verifyDiagnostics,
    llvm::Optional<llvm::ArrayRef<int64_t>> hlfhelinalgTileSizes,
    bool autoParallelize, llvm::raw_ostream &os,
    bool autoParallelize,
    llvm::Optional<mlir::zamalang::KeySetCache> keySetCache,
    llvm::raw_ostream &os,
    std::shared_ptr<mlir::zamalang::CompilerEngine::Library> outputLib) {
  std::shared_ptr<mlir::zamalang::CompilationContext> ccx =
      mlir::zamalang::CompilationContext::createShared();
@@ -262,7 +268,7 @@ mlir::LogicalResult processInputBuffer(

  if (action == Action::JIT_INVOKE) {
    llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
        ce.buildLambda(std::move(buffer), jitFuncName);
        ce.buildLambda(std::move(buffer), jitFuncName, keySetCache);

    if (!lambdaOrErr) {
      mlir::zamalang::log_error()
@@ -376,6 +382,11 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
  if (!cmdline::hlfhelinalgTileSizes.empty())
    hlfhelinalgTileSizes.emplace(cmdline::hlfhelinalgTileSizes);

  llvm::Optional<mlir::zamalang::KeySetCache> jitKeySetCache;
  if (!cmdline::jitKeySetCachePath.empty()) {
    jitKeySetCache = mlir::zamalang::KeySetCache(cmdline::jitKeySetCachePath);
  }

  // In case of compilation to library, the real output is the library.
  std::string outputPath =
      (cmdline::action == Action::COMPILE) ? cmdline::STDOUT : cmdline::output;
@@ -411,7 +422,7 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
        cmdline::jitFuncName, cmdline::jitArgs,
        cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP,
        cmdline::verifyDiagnostics, hlfhelinalgTileSizes,
        cmdline::autoParallelize, os, outputLib);
        cmdline::autoParallelize, jitKeySetCache, os, outputLib);
  };
  auto &os = output->os();
  auto res = mlir::failure();

@@ -34,3 +34,4 @@ def run(*cmd):
    if result.returncode != 0:
        print(result.stderr)
    assert result.returncode == 0, ' '.join(cmd)
    return str(result.stdout, encoding='utf-8')

compiler/tests/stress_tests/README.md (new file)
@@ -0,0 +1,70 @@
# Run and display summary

## Run

You can run either:
- ```make stress-tests```: the replications inside each test are parallelized, but the tests themselves are not
- ```make stress-tests-fast-cache```: the tests (and hence the KeySetCache generation) are also parallelized, useful for a first run

## Summary

```make show-stress-tests-summary```

# Raw results

In the ```tests/stress_tests/trace``` directory:
- ```test_controlled``` contains experiments with controlled code and parameters, cases that should run fine
- ```test_wild``` contains experiments with less controlled code and parameters, which explore the limits of the compiler.

All experiments currently compute a weighted sum with a constant weight, followed by an identity function.

These two directories contain one file per experiment, named ```XXXbits_x_YYY_W```, where XXX is the precision, YYY is the size of the computation and W is the experiment's non-structural parameter (here the weight in the sum).

Files are in JSON format but can easily be grepped (they span multiple lines).
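The same information that ```make show-stress-tests-summary``` greps out of the traces can also be post-processed programmatically. A minimal sketch (not part of the test suite, assuming the ```tests/stress_tests/trace``` layout described above):

```python
import json
import os

TRACE = 'tests/stress_tests/trace'

# Walk every experiment file and print its success rate and condition flags.
for root, _dirs, files in os.walk(TRACE):
    for name in sorted(files):
        with open(os.path.join(root, name)) as f:
            experiment = json.load(f)
        print(name, experiment['success_rate'], experiment['conditions']['details'])
```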
# Experiment file

```json
{
  # Command line to relaunch one experiment replication by hand
  "cmd": "zamacompiler /tmp/stresstests/basic_001_002_1.mlir --action=jit-invoke --jit-funcname=main --jit-args=1 --jit-args=1",
  # General information about the experiment
  "conditions": {
    "bitwidth": 1, # precision in bits
    "size": 2, # size of the computation
    "args": [ # jit arguments
      1,
      1
    ],
    "log_manp_max": 3, # value computed by zamacompiler
    "overflow": true, # does the exact computation overflow the precision
    "details": [
      "OVERFLOW"
    ] # messages related to potential issues
  },
  # Replication results
  "replications": [
    {
      "success": true,
      "details": []
    }, # A successful replication
    {
      "success": true,
      "details": [
        "OVERFLOW 3"
      ]
    }, # A successful replication with an overflow value, the result being correct once truncated
    {
      "success": false,
      "details": [
        "OVERFLOW 3",
        "Expected: 4 vs. 3 (no modulo 0 vs. 1)"
      ]
    } # A failed replication: the result is wrong both directly and after truncation
    ...
  ],
  "code": "\nfunc @main(...) { ... }",
  "success_rate": 99.0,
  "overflow_rate": 100.0
}
```
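The two rates at the bottom of the file are derived from the ```replications``` list. A small sketch (not part of the suite) of how they relate, mirroring the way ```test_stress.py``` fills these fields:

```python
def rates(replications):
    # success_rate: percentage of replications whose truncated result was correct
    # overflow_rate: percentage of replications whose details report an OVERFLOW
    n = len(replications)
    successes = sum(1 for r in replications if r['success'])
    overflows = sum(1 for r in replications
                    if any(d.startswith('OVERFLOW') for d in r['details']))
    return 100.0 * successes / n, 100.0 * overflows / n
```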
compiler/tests/stress_tests/__init__.py (new file, empty)
compiler/tests/stress_tests/experiment.py (new file)
@@ -0,0 +1,32 @@
from dataclasses import asdict, dataclass
import json

@dataclass
class Replication:
    success: bool
    details: 'list[str]'

@dataclass
class ExperimentConditions:
    bitwidth: int
    size: int
    args: 'list[int]'
    log_manp_max: int
    overflow: bool
    details: 'list[str]'

@dataclass
class Experiment:
    cmd: str
    conditions: ExperimentConditions
    replications: 'list[Replication]'
    code: str
    success_rate: float
    overflow_rate: float

class Encoder(json.JSONEncoder):
    def default(self, z):
        try:
            return super().default(z)
        except TypeError:
            # dataclasses are not JSON-serializable by default; fall back to their dict form
            return asdict(z)
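These dataclasses are what the stress tests serialize into the trace files. A minimal usage sketch (the values below are illustrative only, borrowed from the README example):

```python
import json

from stress_tests.experiment import (
    Encoder, Experiment, ExperimentConditions, Replication
)

experiment = Experiment(
    cmd='zamacompiler example.mlir --action=jit-invoke --jit-funcname=main --jit-args=1',
    conditions=ExperimentConditions(
        bitwidth=1, size=2, args=[1, 1],
        log_manp_max=3, overflow=True, details=['OVERFLOW'],
    ),
    replications=[Replication(success=True, details=[])],
    code='func @main(...) { ... }',
    success_rate=100.0,
    overflow_rate=0.0,
)

# Encoder falls back to asdict() for dataclasses, so nested objects serialize too.
print(json.dumps(experiment, indent=2, cls=Encoder))
```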
compiler/tests/stress_tests/read_mlir.py (new file)
@@ -0,0 +1,31 @@
from dataclasses import dataclass
import re

from stress_tests.utils import ZAMACOMPILER, log2, ceil_log2, run


DUMP_HLFHE = '--action=dump-hlfhe'
DUMP_LOWLFHE = '--action=dump-lowlfhe'

def read_max_mlir_attribute(name, content):
    regexp = re.compile(f'{name} = (?P<value>[0-9]+)')
    return max(
        int(found.group('value'))
        for found in regexp.finditer(content)
    )

def log_manp_max(path):
    hlfhe = run(ZAMACOMPILER, path, DUMP_HLFHE)
    return ceil_log2(read_max_mlir_attribute('MANP', hlfhe))

@dataclass
class FHEParams:
    log_poly_size: int
    glwe_dim: int

def v0_param(path):
    lowlfhe = run(ZAMACOMPILER, path, DUMP_LOWLFHE)
    return FHEParams(
        log_poly_size=log2(read_max_mlir_attribute('polynomialSize', lowlfhe)),
        glwe_dim=read_max_mlir_attribute('glweDimension', lowlfhe),
    )
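For reference, a small illustration (with a made-up IR snippet, not real zamacompiler output) of what ```read_max_mlir_attribute``` extracts from a dump:

```python
from stress_tests.read_mlir import read_max_mlir_attribute

# Hypothetical dump: the helper returns the largest value of the named attribute.
dumped_ir = '''
%1 = "HLFHE.add_eint_int"(%0, %c) {MANP = 2 : ui2} : ...
%2 = "HLFHE.add_eint_int"(%1, %c) {MANP = 5 : ui3} : ...
'''

assert read_max_mlir_attribute('MANP', dumped_ir) == 5
```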
compiler/tests/stress_tests/test_stress.py (new file)
@@ -0,0 +1,216 @@
import contextlib
import concurrent.futures as futures
from itertools import chain
import json
import os
from tempfile import gettempdir

import pytest

from stress_tests.experiment import (
    ExperimentConditions, Experiment, Encoder, Replication
)
from stress_tests import read_mlir
from stress_tests.utils import ZAMACOMPILER, run
from stress_tests.v0_parameters import (
    LOG2_MANP_MAX, P_MAX,
    v0_parameter
)

POSSIBLE_BITWIDTH = range(1, P_MAX + 1)
POSSIBLE_SIZE = range(1, 128)

TEST_PATH = os.path.dirname(__file__)
TRACE = os.path.join(TEST_PATH, 'trace')

JIT_INVOKE_MAIN = (
    '--action=jit-invoke',
    '--jit-funcname=main',
    '--jit-keyset-cache-path=/tmp/StresstestsCache',
)

def jit_args(*params):
    return tuple(
        f'--jit-args={p}' for p in params
    )

CONTROLLED_CODE_PARAMS = sorted(chain.from_iterable(
    {  # (bitwidth, size, input value)
        (bitwidth, POSSIBLE_SIZE[-1], 0),
        (bitwidth, 1, 1),
        (bitwidth, bitwidth, 1),
        (bitwidth, 2 ** (bitwidth - 2), 1),
        (bitwidth, 2 ** (bitwidth - 1), 1),
        (bitwidth, 2 ** bitwidth - 1, 1),
        (bitwidth, 2 ** bitwidth, 1),  # force carry
        (bitwidth, 2 ** (bitwidth + 1), 1),  # force overflow and carry 0 ?
    }  # <-- a set to deduplicate similar cases
    for bitwidth in POSSIBLE_BITWIDTH
))


CONTROLLED_CODE_PARAMS = [
    case for case in CONTROLLED_CODE_PARAMS
    if case[1] >= 1
]
TEST_CONTROLLED_REPLICATE = 100

WILD_CODE_PARAMS = list(sorted(chain.from_iterable(
    {  # (bitwidth, size, input value)
        (bitwidth, 2 ** bitwidth + 8, 1),
        (bitwidth, 2 ** bitwidth + 9, 1),
        (bitwidth, 2 ** bitwidth + 16, 1),
        (bitwidth, 2 ** bitwidth + 17, 1),
        (bitwidth, 2 ** (2 * bitwidth), 1),
        (bitwidth, 2 ** (2 * bitwidth) + 1, 1),
    }  # <-- a set to deduplicate similar cases
    for bitwidth in POSSIBLE_BITWIDTH
)))
TEST_WILD_RETRY = 3

def basic_multisum_identity(bitwidth, size):
    def components(name, size, ty=''):
        ty_annot = ' : ' + ty if ty else ''
        return ', '.join(f'%{name}{i}{ty_annot}' for i in range(size))
    def tensor(size, ty):
        return f'tensor<{size}x{ty}>'
    v_ty = f"!HLFHE.eint<{bitwidth}>"
    tv_ty = tensor(size, v_ty)
    w_ty = f"i{bitwidth+1}"
    w_modulo = 2 ** bitwidth  # to match v bitwidth
    tw_ty = tensor(size, w_ty)
    lut_size = 2**bitwidth
    lut_ty = 'i64'
    tlut_ty = tensor(lut_size, lut_ty)

    return (
        f"""
func @main({components('v', size, v_ty)}) -> {v_ty} {{
%v = tensor.from_elements {components('v', size)} : {tv_ty}

// Declare {size} %wX components
{ ''.join(f'''
%w{i} = arith.constant 1: {w_ty}'''
for i in range(size)
)}
%w = tensor.from_elements {components('w', size)} : {tw_ty}

// Declare {lut_size} %lutX components
{ ''.join(f'''
%lut{i} = arith.constant {i}: i64'''
for i in range(lut_size)
)}
%lut = tensor.from_elements {components('lut', lut_size)} : {tlut_ty}

%dot_product = "HLFHELinalg.dot_eint_int"(%v, %w) : ({tv_ty}, {tw_ty}) -> {v_ty}
%pbs_result = "HLFHE.apply_lookup_table"(%dot_product, %lut): ({v_ty}, {tlut_ty}) -> {v_ty}
return %pbs_result: {v_ty}
}}
"""
    )


executor = futures.ThreadPoolExecutor()

def basic_setup(bitwidth, size, const, retry=10):
    code = basic_multisum_identity(bitwidth, size)
    args = (const,) * size
    expected = eval_basic_multisum_identity(bitwidth, args)
    with tmp_file(f'basic_{bitwidth:03}_{size:03}_{const}.mlir', code) as path:
        modulo = 2 ** bitwidth
        # Read various values from the compiler
        log_manp_max = read_mlir.log_manp_max(path)
        params = read_mlir.v0_param(path)
        # From CPP source
        expected_params = v0_parameter(log_manp_max, bitwidth)
        expected_log_poly_size = expected_params.logPolynomialSize
        expected_glwe_dim = expected_params.glweDimension

        conditions_details = []
        def msg(m, append_here=None, space=' '):
            print(m, end=space, flush=True)  # test human output
            if append_here is not None:
                append_here.append(m)

        if LOG2_MANP_MAX < log_manp_max:
            msg('HIGH-MANP', conditions_details)
        if 2 ** bitwidth <= expected:
            msg('OVERFLOW', conditions_details)
        if params.log_poly_size != expected_log_poly_size:
            msg(f'BAD_LOGPOLYSIZE({params.log_poly_size} vs {expected_log_poly_size})', conditions_details)
        if params.glwe_dim != expected_glwe_dim:
            msg(f'BAD_GLWEDIM ({params.glwe_dim} vs {expected_glwe_dim})', conditions_details)

        cmd = (ZAMACOMPILER, path) + JIT_INVOKE_MAIN + jit_args(*args)
        compilers_calls = [executor.submit(run, *cmd) for _ in range(retry)]

        success = 0
        overflow = 0
        replications = []
        for replication in futures.as_completed(compilers_calls):
            result = int(replication.result().splitlines()[-1])
            correct_in_modulo = expected % modulo == result % modulo
            details = []
            replications.append(Replication(correct_in_modulo, details))
            if not (0 <= result < modulo):
                msg(f'OVERFLOW {result}', details)
                overflow += 1
            if correct_in_modulo:
                msg('O', space='')
                success += 1
            else:
                msg('X', space='')
                diff = f'Expected: {expected % modulo} vs. {result % modulo} (no modulo {expected} vs. {result})'
                details.append(diff)

        print(' ', end='')
        add_to(TRACE, Experiment(
            cmd=' '.join(cmd),
            conditions=ExperimentConditions(
                bitwidth=bitwidth, size=size, args=args,
                log_manp_max=log_manp_max,
                overflow=2 ** bitwidth <= expected,
                details=conditions_details,
            ),
            replications=replications,
            code=code,
            success_rate=100.0 * success / retry,
            overflow_rate=100.0 * overflow / retry,
        ))

        assert success == len(replications)


def eval_basic_multisum_identity(bitwidth, args):
    return sum(
        arg
        for arg in args
    )

@contextlib.contextmanager
def tmp_file(name, content, delete=False):
    path = os.path.join(gettempdir(), 'stresstests', name)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'w') as f:
        f.write(content)
    yield f.name
    if delete:
        os.remove(path)

def add_to(DIR, expe: Experiment):
    full_test_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    test_name = full_test_name.rsplit('[', 1)[0]
    DIR = os.path.join(DIR, test_name)
    os.makedirs(DIR, exist_ok=True)
    conditions = expe.conditions
    name = f'{conditions.bitwidth:03}bits_x_{conditions.size:03}_{conditions.args[0]}'
    with open(os.path.join(DIR, name), 'w') as f:
        json.dump(expe, f, indent=2, cls=Encoder)


@pytest.mark.parametrize("bitwidth, size, const", CONTROLLED_CODE_PARAMS)
def test_controlled(bitwidth, size, const):
    return basic_setup(bitwidth, size, const, TEST_CONTROLLED_REPLICATE)

@pytest.mark.parametrize("bitwidth, size, const", WILD_CODE_PARAMS)
def test_wild(bitwidth, size, const):
    return basic_setup(bitwidth, size, const, TEST_WILD_RETRY)
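To make the overflow bookkeeping in ```basic_setup``` easier to follow, here is a plain-Python model (illustration only, not part of the suite) of what each generated MLIR program computes: a weighted sum with all weights equal to 1, an identity lookup table, and results living modulo 2**bitwidth.

```python
def plain_multisum_identity(bitwidth, args):
    # weighted sum with constant weight 1, then an identity lookup table
    acc = sum(arg * 1 for arg in args)
    # encrypted values only carry `bitwidth` bits, so a replication is counted as a
    # success when the result matches the expected value modulo 2**bitwidth
    return acc % (2 ** bitwidth)

# Example: bitwidth=2 and five inputs equal to 1 -> the exact sum 5 overflows 2 bits,
# but a replication returning 1 (= 5 % 4) is still counted as a success.
assert plain_multisum_identity(2, [1] * 5) == 1
```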
compiler/tests/stress_tests/utils.py (new file)
@@ -0,0 +1,19 @@
import subprocess

ZAMACOMPILER = 'zamacompiler'

def ceil_log2(v, exact=False):
    import math
    log_v = math.ceil(math.log(v) / math.log(2))
    assert not exact or v == 2 ** log_v
    return log_v

def log2(v: int):
    return ceil_log2(v, exact=True)

def run(*cmd):
    result = subprocess.run(cmd, capture_output=True)
    if result.returncode != 0:
        print(result.stderr)
    assert result.returncode == 0, ' '.join(cmd)
    return str(result.stdout, encoding='utf-8')
compiler/tests/stress_tests/v0_parameters.py (new file)
@@ -0,0 +1,50 @@
""" Read parameters matrix in V0Parameters.cpp """
from dataclasses import dataclass
import re

@dataclass
class V0Parameter:
    glweDimension: int
    logPolynomialSize: int
    nSmall: int
    brLevel: int
    brLogBase: int
    ksLevel: int
    ksLogBase: int

# [log_manp][bitwidth]
v0_parameters: 'list[list[V0Parameter]]' = []

def v0_parameter(log_manp_max, bitwidth):
    try:
        return v0_parameters[log_manp_max - 1][bitwidth - 1]
    except IndexError:
        return V0Parameter(*(['out_of_V0Parameters'] * 7))

# relative to Makefile
V0Parameters_PATH = 'lib/Support/V0Parameters.cpp'

def read_CPP_decl(name, cpp_filepath):
    DECLARE = re.compile(f'{name}[^=\n]+=')
    END = re.compile(';')
    with open(cpp_filepath) as f:
        content = f.read()
    decl = DECLARE.search(content)
    if not decl:
        raise NameError(f'Cannot find {name} declaration in file {cpp_filepath}')
    end = END.search(content, decl.end())
    assert end
    value = content[decl.end()+1:end.end()-1].strip()
    # print(f'{name} = {value}', decl.group())
    return value

LOG2_MANP_MAX = int(read_CPP_decl('NORM2_MAX', V0Parameters_PATH))
P_MAX = int(read_CPP_decl('P_MAX', V0Parameters_PATH))

try:
    parameters_cpp = read_CPP_decl('parameters', V0Parameters_PATH)
except (FileNotFoundError, NameError) as exc:
    print(exc)
    assert False

v0_parameters = eval(parameters_cpp.replace('{', '[').replace('}', ']'))
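A minimal illustration of the extraction done by ```read_CPP_decl```, re-implemented on an in-memory string so it can be read without the real ```lib/Support/V0Parameters.cpp``` at hand (the C++ snippet below is made up):

```python
import re

cpp_source = '''
constexpr size_t NORM2_MAX = 31;
constexpr size_t P_MAX = 7;
'''

def read_decl_from_string(name, content):
    # same logic as read_CPP_decl: grab the text between '<name> ... =' and the next ';'
    decl = re.compile(f'{name}[^=\n]+=').search(content)
    end = re.compile(';').search(content, decl.end())
    return content[decl.end() + 1:end.end() - 1].strip()

assert read_decl_from_string('NORM2_MAX', cpp_source) == '31'
assert read_decl_from_string('P_MAX', cpp_source) == '7'
```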