mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
fix(benchmarks): update some bounds, bump python library, change accuracy calculation method
This commit is contained in:
2
Makefile
2
Makefile
@@ -60,7 +60,7 @@ pylint_tests:
|
||||
pylint_benchmarks:
|
||||
@# Disable duplicate code detection, docstring requirement, too many locals/statements
|
||||
find ./benchmarks/ -type f -name "*.py" | xargs poetry run pylint \
|
||||
--disable=R0801,R0914,R0915,C0103,C0114,C0115,C0116,W0108 --rcfile=pylintrc
|
||||
--disable=R0801,R0914,R0915,C0103,C0114,C0115,C0116,C0302,W0108 --rcfile=pylintrc
|
||||
|
||||
.PHONY: pylint_script # Run pylint on scripts
|
||||
pylint_script:
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# pylint: disable=too-many-lines
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import progress
|
||||
from common import BENCHMARK_CONFIGURATION
|
||||
|
||||
import concrete.numpy as hnp
|
||||
|
||||
@@ -813,7 +813,7 @@ import concrete.numpy as hnp
|
||||
"type": "encrypted",
|
||||
"shape": (2, 3),
|
||||
"minimum": 0,
|
||||
"maximum": 30,
|
||||
"maximum": 25,
|
||||
},
|
||||
},
|
||||
"accuracy_alert_threshold": 100,
|
||||
@@ -829,13 +829,13 @@ import concrete.numpy as hnp
|
||||
"type": "encrypted",
|
||||
"shape": (2, 3),
|
||||
"minimum": 0,
|
||||
"maximum": 14,
|
||||
"maximum": 15,
|
||||
},
|
||||
"y": {
|
||||
"type": "encrypted",
|
||||
"shape": (3, 2),
|
||||
"minimum": 0,
|
||||
"maximum": 3,
|
||||
"maximum": 4,
|
||||
},
|
||||
},
|
||||
"accuracy_alert_threshold": 100,
|
||||
@@ -975,14 +975,14 @@ import concrete.numpy as hnp
|
||||
"function": lambda x: hnp.MultiLookupTable(
|
||||
[
|
||||
[
|
||||
hnp.LookupTable([(i ** 5) + 2 % 32 for i in range(32)]),
|
||||
hnp.LookupTable([(i ** 5) * 3 % 32 for i in range(32)]),
|
||||
hnp.LookupTable([(i ** 5) // 6 % 32 for i in range(32)]),
|
||||
hnp.LookupTable([((i ** 5) + 2) % 32 for i in range(32)]),
|
||||
hnp.LookupTable([((i ** 5) * 3) % 32 for i in range(32)]),
|
||||
hnp.LookupTable([((i ** 5) // 6) % 32 for i in range(32)]),
|
||||
],
|
||||
[
|
||||
hnp.LookupTable([(i ** 5) // 2 % 32 for i in range(32)]),
|
||||
hnp.LookupTable([(i ** 5) + 5 % 32 for i in range(32)]),
|
||||
hnp.LookupTable([(i ** 5) * 4 % 32 for i in range(32)]),
|
||||
hnp.LookupTable([((i ** 5) // 2) % 32 for i in range(32)]),
|
||||
hnp.LookupTable([((i ** 5) + 5) % 32 for i in range(32)]),
|
||||
hnp.LookupTable([((i ** 5) * 4) % 32 for i in range(32)]),
|
||||
],
|
||||
]
|
||||
)[x],
|
||||
@@ -1035,14 +1035,14 @@ import concrete.numpy as hnp
|
||||
"function": lambda x: hnp.MultiLookupTable(
|
||||
[
|
||||
[
|
||||
hnp.LookupTable([(i ** 6) + 2 % 64 for i in range(64)]),
|
||||
hnp.LookupTable([(i ** 6) * 3 % 64 for i in range(64)]),
|
||||
hnp.LookupTable([(i ** 6) // 6 % 64 for i in range(64)]),
|
||||
hnp.LookupTable([((i ** 6) + 2) % 64 for i in range(64)]),
|
||||
hnp.LookupTable([((i ** 6) * 3) % 64 for i in range(64)]),
|
||||
hnp.LookupTable([((i ** 6) // 6) % 64 for i in range(64)]),
|
||||
],
|
||||
[
|
||||
hnp.LookupTable([(i ** 6) // 2 % 64 for i in range(64)]),
|
||||
hnp.LookupTable([(i ** 6) + 5 % 64 for i in range(64)]),
|
||||
hnp.LookupTable([(i ** 6) * 4 % 64 for i in range(64)]),
|
||||
hnp.LookupTable([((i ** 6) // 2) % 64 for i in range(64)]),
|
||||
hnp.LookupTable([((i ** 6) + 5) % 64 for i in range(64)]),
|
||||
hnp.LookupTable([((i ** 6) * 4) % 64 for i in range(64)]),
|
||||
],
|
||||
]
|
||||
)[x],
|
||||
@@ -1095,14 +1095,14 @@ import concrete.numpy as hnp
|
||||
"function": lambda x: hnp.MultiLookupTable(
|
||||
[
|
||||
[
|
||||
hnp.LookupTable([(i ** 7) + 2 % 128 for i in range(128)]),
|
||||
hnp.LookupTable([(i ** 7) * 3 % 128 for i in range(128)]),
|
||||
hnp.LookupTable([(i ** 7) // 6 % 128 for i in range(128)]),
|
||||
hnp.LookupTable([((i ** 7) + 2) % 128 for i in range(128)]),
|
||||
hnp.LookupTable([((i ** 7) * 3) % 128 for i in range(128)]),
|
||||
hnp.LookupTable([((i ** 7) // 6) % 128 for i in range(128)]),
|
||||
],
|
||||
[
|
||||
hnp.LookupTable([(i ** 7) // 2 % 128 for i in range(128)]),
|
||||
hnp.LookupTable([(i ** 7) + 5 % 128 for i in range(128)]),
|
||||
hnp.LookupTable([(i ** 7) * 4 % 128 for i in range(128)]),
|
||||
hnp.LookupTable([((i ** 7) // 2) % 128 for i in range(128)]),
|
||||
hnp.LookupTable([((i ** 7) + 5) % 128 for i in range(128)]),
|
||||
hnp.LookupTable([((i ** 7) * 4) % 128 for i in range(128)]),
|
||||
],
|
||||
]
|
||||
)[x],
|
||||
@@ -1510,7 +1510,9 @@ def main(function, inputs, accuracy_alert_threshold):
|
||||
inputset.append(tuple(input_) if len(input_) > 1 else input_[0])
|
||||
|
||||
compiler = hnp.NPFHECompiler(
|
||||
function, {name: description["type"] for name, description in inputs.items()}
|
||||
function,
|
||||
{name: description["type"] for name, description in inputs.items()},
|
||||
compilation_configuration=BENCHMARK_CONFIGURATION,
|
||||
)
|
||||
|
||||
circuit = compiler.compile_on_inputset(inputset)
|
||||
@@ -1540,8 +1542,11 @@ def main(function, inputs, accuracy_alert_threshold):
|
||||
with progress.measure(id="evaluation-time-ms", label="Evaluation Time (ms)"):
|
||||
result_i = circuit.run(*sample_i)
|
||||
|
||||
if np.array_equal(result_i, expectation_i):
|
||||
correct += 1
|
||||
np_result_i = np.array(result_i, dtype=np.uint8)
|
||||
np_expectation_i = np.array(expectation_i, dtype=np.uint8)
|
||||
|
||||
if np_result_i.shape == np_expectation_i.shape:
|
||||
correct += np.sum(np_result_i == np_expectation_i) / np_result_i.size
|
||||
accuracy = (correct / len(samples)) * 100
|
||||
|
||||
print(f"Accuracy (%): {accuracy:.4f}")
|
||||
|
||||
@@ -157,7 +157,7 @@ def score_concrete_glm_estimator(poisson_glm_pca, q_glm, df_test):
|
||||
return score_estimator(y_pred, df_test["Frequency"], df_test["Exposure"])
|
||||
|
||||
|
||||
@progress.track([{"id": "glm", "name": "Generalized Linear Model", "parameters": {}}])
|
||||
@progress.track([{"id": "glm", "name": "Generalized Linear Model"}])
|
||||
def main():
|
||||
"""
|
||||
This is our main benchmark function. It gets a dataset, trains a GLM model,
|
||||
|
||||
@@ -94,7 +94,7 @@ class QuantizedLinearRegression(QuantizedModule):
|
||||
return q_input_arr
|
||||
|
||||
|
||||
@progress.track([{"id": "linear-regression", "name": "Linear Regression", "parameters": {}}])
|
||||
@progress.track([{"id": "linear-regression", "name": "Linear Regression"}])
|
||||
def main():
|
||||
"""
|
||||
Our linear regression benchmark. Use some synthetic data to train a regression model,
|
||||
|
||||
@@ -116,7 +116,7 @@ class QuantizedLogisticRegression(QuantizedModule):
|
||||
return q_input_arr
|
||||
|
||||
|
||||
@progress.track([{"id": "logistic-regression", "name": "Logistic Regression", "parameters": {}}])
|
||||
@progress.track([{"id": "logistic-regression", "name": "Logistic Regression"}])
|
||||
def main():
|
||||
"""Main benchmark function: generate some synthetic data for two class classification,
|
||||
split train-test, train a sklearn classifier, calibrate and quantize it on the whole dataset
|
||||
|
||||
14
poetry.lock
generated
14
poetry.lock
generated
@@ -1354,7 +1354,7 @@ python-versions = "*"
|
||||
|
||||
[[package]]
|
||||
name = "py-progress-tracker"
|
||||
version = "0.1.0"
|
||||
version = "0.3.3"
|
||||
description = "A simple benchmarking library"
|
||||
category = "dev"
|
||||
optional = false
|
||||
@@ -1742,7 +1742,7 @@ md = ["cmarkgfm (>=0.5.0,<0.7.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "requests"
|
||||
version = "2.27.0"
|
||||
version = "2.27.1"
|
||||
description = "Python HTTP for Humans."
|
||||
category = "dev"
|
||||
optional = false
|
||||
@@ -2282,7 +2282,7 @@ full = ["pygraphviz"]
|
||||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = ">=3.8,<3.10"
|
||||
content-hash = "51dccbf357cf2a087c60beef5bc118d0ef469b5c5c7d72794550c4ea3c318f28"
|
||||
content-hash = "70b6612a3502ea69dc5cd7d89d92245dc102a9acf8797562c2e26a1e074818a1"
|
||||
|
||||
[metadata.files]
|
||||
alabaster = [
|
||||
@@ -3193,8 +3193,8 @@ py-cpuinfo = [
|
||||
{file = "py-cpuinfo-8.0.0.tar.gz", hash = "sha256:5f269be0e08e33fd959de96b34cd4aeeeacac014dd8305f70eb28d06de2345c5"},
|
||||
]
|
||||
py-progress-tracker = [
|
||||
{file = "py-progress-tracker-0.1.0.tar.gz", hash = "sha256:ebda5b1e9d87a6cb8d9af02c625f372d89e7b6f52d0637cb8476db25ea55f5b4"},
|
||||
{file = "py_progress_tracker-0.1.0-py3-none-any.whl", hash = "sha256:5fcc8abaea1c46ea81fa2e99f2028cd988cf022db22bd7d70c684644534fccf9"},
|
||||
{file = "py-progress-tracker-0.3.3.tar.gz", hash = "sha256:344a312bc183f4ab4fca5deb5d7d8b94195d3e4c81a2aa929cefee63952ac4d2"},
|
||||
{file = "py_progress_tracker-0.3.3-py3-none-any.whl", hash = "sha256:f298f203c86c32539ba50ee955e8f7121e1095e0704436057f405e2527c7695c"},
|
||||
]
|
||||
pycodestyle = [
|
||||
{file = "pycodestyle-2.7.0-py2.py3-none-any.whl", hash = "sha256:514f76d918fcc0b55c6680472f0a37970994e07bbb80725808c17089be302068"},
|
||||
@@ -3458,8 +3458,8 @@ readme-renderer = [
|
||||
{file = "readme_renderer-32.0.tar.gz", hash = "sha256:b512beafa6798260c7d5af3e1b1f097e58bfcd9a575da7c4ddd5e037490a5b85"},
|
||||
]
|
||||
requests = [
|
||||
{file = "requests-2.27.0-py2.py3-none-any.whl", hash = "sha256:f71a09d7feba4a6b64ffd8e9d9bc60f9bf7d7e19fd0e04362acb1cfc2e3d98df"},
|
||||
{file = "requests-2.27.0.tar.gz", hash = "sha256:8e5643905bf20a308e25e4c1dd379117c09000bf8a82ebccc462cfb1b34a16b5"},
|
||||
{file = "requests-2.27.1-py2.py3-none-any.whl", hash = "sha256:f22fa1e554c9ddfd16e6e41ac79759e17be9e492b3587efa038054674760e72d"},
|
||||
{file = "requests-2.27.1.tar.gz", hash = "sha256:68d7c56fd5a8999887728ef304a6d12edc7be74f1cfa47714fc8b414525c9a61"},
|
||||
]
|
||||
requests-toolbelt = [
|
||||
{file = "requests-toolbelt-0.9.1.tar.gz", hash = "sha256:968089d4584ad4ad7c171454f0a5c6dac23971e9472521ea3b6d49d610aa6fc0"},
|
||||
|
||||
@@ -53,7 +53,7 @@ scikit-learn = "1.0.1"
|
||||
pandas = "1.3.4"
|
||||
pip-audit = "^1.1.1"
|
||||
pytest-codeblocks = "^0.12.2"
|
||||
py-progress-tracker = "^0.1.0"
|
||||
py-progress-tracker = "^0.3.3"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
|
||||
@@ -21,8 +21,9 @@ if ! source "${DEV_VENV_PATH}/bin/activate"; then
|
||||
source "${DEV_VENV_PATH}/bin/activate"
|
||||
fi
|
||||
|
||||
cd /src/ && make sync_env
|
||||
cd /src/ && make setup_env
|
||||
|
||||
mkdir -p /tmp/keycache
|
||||
mkdir -p logs
|
||||
|
||||
initial_concrete_log=logs/$(date -u --iso-8601=seconds).concrete.log
|
||||
|
||||
Reference in New Issue
Block a user