feat: mimic the exact numpy behavior for matmul

This commit is contained in:
Umut
2022-02-28 15:06:04 +03:00
parent ed28639c57
commit b71cbc8ecb
8 changed files with 108 additions and 86 deletions

View File

@@ -628,7 +628,13 @@ class IntermediateNodeConverter:
resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
preds = self.preds
if self.node.inputs[0].is_clear:
assert isinstance(self.node.outputs[0], TensorValue)
if self.node.outputs[0].shape == ():
if self.node.inputs[0].is_clear:
preds = preds[::-1]
result = fhelinalg.Dot(resulting_type, *preds).result
elif self.node.inputs[0].is_clear:
result = fhelinalg.MatMulIntEintOp(resulting_type, *preds).result
else:
result = fhelinalg.MatMulEintIntOp(resulting_type, *preds).result

View File

@@ -630,30 +630,14 @@ class MatMul(IntermediateNode):
self,
inputs: Iterable[BaseValue],
output_dtype: BaseDataType,
output_shape: Tuple[int, ...],
) -> None:
super().__init__(inputs)
assert_true(len(self.inputs) == 2)
assert_true(
all(
isinstance(input_value, TensorValue) and input_value.ndim == 2
for input_value in self.inputs
),
f"MatMul only supports two matrices ({TensorValue.__name__} with ndim == 2)",
)
lhs = cast(TensorValue, self.inputs[0])
rhs = cast(TensorValue, self.inputs[1])
assert_true(
lhs.shape[1] == rhs.shape[0],
f"MatMul between matrices of shapes {lhs.shape} and {rhs.shape} is not supported",
)
output_shape = (lhs.shape[0], rhs.shape[1])
output_value = (
EncryptedTensor(dtype=output_dtype, shape=output_shape)
if (lhs.is_encrypted or rhs.is_encrypted)
if (self.inputs[0].is_encrypted or self.inputs[1].is_encrypted)
else ClearTensor(dtype=output_dtype, shape=output_shape)
)

View File

@@ -773,36 +773,12 @@ def _on_numpy_matmul(lhs: NPTracer, rhs: NPTracer):
assert_true(len(common_output_dtypes_and_shapes) == 1)
output_shape = common_output_dtypes_and_shapes[0][1]
# TODO: https://github.com/zama-ai/concrete-numpy-internal/issues/1174
# remove all the reshape logic once matmul supports more combinations of arguments
if isinstance(lhs_output := lhs.output, TensorValue) and isinstance(
rhs_output := rhs.output, TensorValue
):
# Manage non 2D cases
if lhs_output.ndim == 1 and rhs_output.ndim == 1:
lhs = lhs.reshape((1, lhs_output.shape[0]))
rhs = rhs.reshape((rhs_output.shape[0], 1))
elif lhs_output.ndim == 1:
# lhs is a vector, reshape to be 2D and give proper result
lhs = lhs.reshape((1, lhs_output.shape[0]))
elif rhs_output.ndim == 1:
# rhs is a vector, reshape to be 2D and give proper result
rhs = rhs.reshape((rhs_output.shape[0], 1))
traced_computation = MatMul(
[lhs.output, rhs.output],
common_output_dtypes_and_shapes[0][0],
output_shape,
)
matmul_tracer = NPTracer([lhs, rhs], traced_computation, output_idx=0)
# Return the reshaped result if vector reshaping for 2D matmul happened
if matmul_tracer.shape != output_shape:
if output_shape == ():
return matmul_tracer[0, 0]
return matmul_tracer.reshape(output_shape)
return matmul_tracer
return NPTracer([lhs, rhs], traced_computation, output_idx=0)
NPTracer.UFUNC_ROUTING[numpy.add] = _on_numpy_add

View File

@@ -1,13 +1,13 @@
Name Version License
Pillow 9.0.1 Historical Permission Notice and Disclaimer (HPND)
PyYAML 6.0 MIT License
concrete-compiler 0.3.1 BSD-3
concrete-compiler 0.4.0 BSD-3
cycler 0.11.0 BSD License
fonttools 4.29.1 MIT License
kiwisolver 1.3.2 BSD License
loguru 0.5.3 MIT License
matplotlib 3.5.1 Python Software Foundation License
networkx 2.7 BSD License
networkx 2.7.1 BSD License
numpy 1.22.2 BSD License
packaging 21.3 Apache Software License; BSD License
pygraphviz 1.9 BSD License

76
poetry.lock generated
View File

@@ -243,7 +243,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
[[package]]
name = "concrete-compiler"
version = "0.3.1"
version = "0.4.0"
description = "Concrete Compiler"
category = "main"
optional = false
@@ -553,7 +553,7 @@ test = ["pytest (!=5.3.4)", "pytest-cov", "flaky", "ipyparallel"]
[[package]]
name = "ipython"
version = "8.1.0"
version = "8.1.1"
description = "IPython: Productive Interactive Computing"
category = "dev"
optional = false
@@ -725,7 +725,7 @@ test = ["codecov", "coverage", "ipykernel", "ipython", "mock", "mypy", "pre-comm
[[package]]
name = "jupyter-console"
version = "6.4.0"
version = "6.4.2"
description = "Jupyter terminal console"
category = "dev"
optional = false
@@ -975,7 +975,7 @@ testing = ["beautifulsoup4", "coverage", "docutils (>=0.17.0,<0.18.0)", "pytest
[[package]]
name = "nbclient"
version = "0.5.11"
version = "0.5.12"
description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor."
category = "dev"
optional = false
@@ -985,7 +985,7 @@ python-versions = ">=3.7.0"
jupyter-client = ">=6.1.5"
nbformat = ">=5.0"
nest-asyncio = "*"
traitlets = ">=4.2"
traitlets = ">=5.0.0"
[package.extras]
sphinx = ["Sphinx (>=1.7)", "sphinx-book-theme", "mock", "moto", "myst-parser"]
@@ -1081,7 +1081,7 @@ python-versions = ">=3.5"
[[package]]
name = "networkx"
version = "2.7"
version = "2.7.1"
description = "Python package for creating and manipulating graphs and networks"
category = "main"
optional = false
@@ -1372,7 +1372,7 @@ python-versions = "*"
[[package]]
name = "py-progress-tracker"
version = "0.4.2"
version = "0.4.6"
description = "A simple benchmarking library"
category = "dev"
optional = false
@@ -1680,11 +1680,11 @@ python-versions = "*"
[[package]]
name = "pywinpty"
version = "2.0.2"
version = "2.0.5"
description = "Pseudo terminal support for Windows from Python."
category = "dev"
optional = false
python-versions = ">=3.6"
python-versions = ">=3.7"
[[package]]
name = "pyyaml"
@@ -1744,7 +1744,7 @@ test = ["pytest (>=6.0.0)", "pytest-cov (>=3.0.0)", "pytest-qt"]
[[package]]
name = "readme-renderer"
version = "32.0"
version = "33.0"
description = "readme_renderer is a library for rendering \"readme\" descriptions for Warehouse"
category = "dev"
optional = false
@@ -1756,7 +1756,7 @@ docutils = ">=0.13.1"
Pygments = ">=2.5.1"
[package.extras]
md = ["cmarkgfm (>=0.5.0,<0.7.0)"]
md = ["cmarkgfm (>=0.8.0)"]
[[package]]
name = "requests"
@@ -2062,7 +2062,7 @@ python-versions = "*"
[[package]]
name = "terminado"
version = "0.13.1"
version = "0.13.2"
description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library."
category = "dev"
optional = false
@@ -2279,7 +2279,7 @@ full = ["pygraphviz"]
[metadata]
lock-version = "1.1"
python-versions = ">=3.8,<3.10"
content-hash = "89f3c912cef146d06a3da96e36347367e1703fa5670f3153637e1822c5e81897"
content-hash = "126b3e771561008efd62e7302c2bc0db6eb55879478553cfd7b598e42a5a0559"
[metadata.files]
alabaster = [
@@ -2430,12 +2430,12 @@ colorama = [
{file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"},
]
concrete-compiler = [
{file = "concrete_compiler-0.3.1-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a6f3d27e3246af55d3f8116c12bd63747691939ca6b0c97170d6d7eda5bd3bb7"},
{file = "concrete_compiler-0.3.1-cp310-cp310-manylinux_2_24_x86_64.whl", hash = "sha256:69e77d45a5df39758bbd38c3fa154d479ad7855afdc06bb7f93c75424d00eae8"},
{file = "concrete_compiler-0.3.1-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:fc2d87aebf0c6772dc9e443194b8c3a363614c6fb1042a5e86a9f5f77a76e360"},
{file = "concrete_compiler-0.3.1-cp38-cp38-manylinux_2_24_x86_64.whl", hash = "sha256:9adc23818a2d64d24e0ab94fd40938b6aaf5ae32c8ad6b562158bd55914ac319"},
{file = "concrete_compiler-0.3.1-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:90ff4fc19dea3f28d7f177ab53979be533443424534f7ee7cce2f0622b82eb58"},
{file = "concrete_compiler-0.3.1-cp39-cp39-manylinux_2_24_x86_64.whl", hash = "sha256:6d21fe84c739e482e3d1578d1f567226047736e77346d57acc94de5a2b108251"},
{file = "concrete_compiler-0.4.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:382cea30b9e7805dbb6620d1d3e31f66a3346c7e8ff4cae30c507bb843c0e5b1"},
{file = "concrete_compiler-0.4.0-cp310-cp310-manylinux_2_24_x86_64.whl", hash = "sha256:1681e7273d39c2ef8e970db3fc0039504b4cd1fdef788eb677fa2ada5f0a9487"},
{file = "concrete_compiler-0.4.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:8fec775022d91886aae0f4e19061b1423f2c9d092bb444422417ce342c5981a3"},
{file = "concrete_compiler-0.4.0-cp38-cp38-manylinux_2_24_x86_64.whl", hash = "sha256:3b3f31145688e58a95a6d1e9604b42e80f03d827ba11014f7611d394bcd60e78"},
{file = "concrete_compiler-0.4.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:0ed0cd4355ef60710601ba7d7479e5bd3546764faeafc612e3e42f49d9c290bd"},
{file = "concrete_compiler-0.4.0-cp39-cp39-manylinux_2_24_x86_64.whl", hash = "sha256:e25c760d4f08a39c0f0af895497e8da457e9db0df9cf8752e3735150fc8d3227"},
]
coverage = [
{file = "coverage-6.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9b27d894748475fa858f9597c0ee1d4829f44683f3813633aaf94b19cb5453cf"},
@@ -2614,8 +2614,8 @@ ipykernel = [
{file = "ipykernel-6.9.1.tar.gz", hash = "sha256:f95070a2dfd3147f8ab19f18ee46733310813758593745e07ec18fb08b409f1d"},
]
ipython = [
{file = "ipython-8.1.0-py3-none-any.whl", hash = "sha256:7bfeb6f298b2d7f3859c4f3e134082015cf34de90f89f5020e107a5a762ef6db"},
{file = "ipython-8.1.0.tar.gz", hash = "sha256:42c23e90b2deaae631266885de1656a517a1673d7e1db57e8eb3a4bb6cd5ce1b"},
{file = "ipython-8.1.1-py3-none-any.whl", hash = "sha256:6f56bfaeaa3247aa3b9cd3b8cbab3a9c0abf7428392f97b21902d12b2f42a381"},
{file = "ipython-8.1.1.tar.gz", hash = "sha256:8138762243c9b3a3ffcf70b37151a2a35c23d3a29f9743878c33624f4207be3d"},
]
ipython-genutils = [
{file = "ipython_genutils-0.2.0-py2.py3-none-any.whl", hash = "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8"},
@@ -2655,8 +2655,8 @@ jupyter-client = [
{file = "jupyter_client-7.1.2.tar.gz", hash = "sha256:4ea61033726c8e579edb55626d8ee2e6bf0a83158ddf3751b8dd46b2c5cd1e96"},
]
jupyter-console = [
{file = "jupyter_console-6.4.0-py3-none-any.whl", hash = "sha256:7799c4ea951e0e96ba8260575423cb323ea5a03fcf5503560fa3e15748869e27"},
{file = "jupyter_console-6.4.0.tar.gz", hash = "sha256:242248e1685039cd8bff2c2ecb7ce6c1546eb50ee3b08519729e6e881aec19c7"},
{file = "jupyter_console-6.4.2-py3-none-any.whl", hash = "sha256:1d52cf1a80f0c7accaa2b8c68ca3e6fa2311ec33ac9651d6cb6b9168cca1dad9"},
{file = "jupyter_console-6.4.2.tar.gz", hash = "sha256:fce5bccac926c690924168ad46cae33a7d78d643a7b60af0f260af25d38ecf26"},
]
jupyter-core = [
{file = "jupyter_core-4.9.2-py3-none-any.whl", hash = "sha256:f875e4d27e202590311d468fa55f90c575f201490bd0c18acabe4e318db4a46d"},
@@ -2933,8 +2933,8 @@ myst-parser = [
{file = "myst_parser-0.15.2-py3-none-any.whl", hash = "sha256:40124b6f27a4c42ac7f06b385e23a9dcd03d84801e9c7130b59b3729a554b1f9"},
]
nbclient = [
{file = "nbclient-0.5.11-py3-none-any.whl", hash = "sha256:03e857bea3012377289daa1e1c1651f4fc0295bcd109ccd36a337efcdbebaed7"},
{file = "nbclient-0.5.11.tar.gz", hash = "sha256:751516992f34b58172bad54eef1e4bf7e4f4460d58e255ca1a4e5c9649476007"},
{file = "nbclient-0.5.12-py3-none-any.whl", hash = "sha256:ff2d908024aaabb8864e5392c3517c76e17994b1f9330dda9b5284da9275c499"},
{file = "nbclient-0.5.12.tar.gz", hash = "sha256:0dd7ee6db59753563035f606421c3b558bd8c28b116d7e3ab8a0b4026cb44e38"},
]
nbconvert = [
{file = "nbconvert-6.4.2-py3-none-any.whl", hash = "sha256:7b006ae9979af56200e7fa3db39d9d12c99e811e8843b05dbe518e5b754bcb2e"},
@@ -2957,8 +2957,8 @@ nest-asyncio = [
{file = "nest_asyncio-1.5.4.tar.gz", hash = "sha256:f969f6013a16fadb4adcf09d11a68a4f617c6049d7af7ac2c676110169a63abd"},
]
networkx = [
{file = "networkx-2.7-py3-none-any.whl", hash = "sha256:836544e160f1b7ebf720c01667b7d2f5724c0424374aef2b665ed64dd14b0c7d"},
{file = "networkx-2.7.tar.gz", hash = "sha256:effb7d9cd5c36e1e0d33f42a3aef5badde5030535826a367d5cf608a170af515"},
{file = "networkx-2.7.1-py3-none-any.whl", hash = "sha256:011e85d277c89681e8fa661cf5ff0743443445049b0b68789ad55ef09340c6e0"},
{file = "networkx-2.7.1.tar.gz", hash = "sha256:d1194ba753e5eed07cdecd1d23c5cd7a3c772099bd8dbd2fea366788cf4de7ba"},
]
notebook = [
{file = "notebook-6.4.8-py3-none-any.whl", hash = "sha256:3e702fcc54b8ae597533c3864793b7a1e971dec9e112f67235828d8a798fd654"},
@@ -3138,8 +3138,8 @@ py-cpuinfo = [
{file = "py-cpuinfo-8.0.0.tar.gz", hash = "sha256:5f269be0e08e33fd959de96b34cd4aeeeacac014dd8305f70eb28d06de2345c5"},
]
py-progress-tracker = [
{file = "py-progress-tracker-0.4.2.tar.gz", hash = "sha256:92cbef419d923fe75ae3561312afbbfcd53082e720fb5bb4abf90c1299e7770f"},
{file = "py_progress_tracker-0.4.2-py3-none-any.whl", hash = "sha256:b8e58f6b6e41cfda777c8b7b4da571e278cea4539f2c22a361703ad20bcabf84"},
{file = "py-progress-tracker-0.4.6.tar.gz", hash = "sha256:32a20b035818b9c8af70b2c7ff1edd037396131e6c2c0c3dbf5ba720a841994c"},
{file = "py_progress_tracker-0.4.6-py3-none-any.whl", hash = "sha256:fb1fed8a4ed996be32af47802b046df8b4c42b5cce80c7d2e44ff3c15a056126"},
]
pycodestyle = [
{file = "pycodestyle-2.8.0-py2.py3-none-any.whl", hash = "sha256:720f8b39dde8b293825e7ff02c475f3077124006db4f440dcbc9a20b76548a20"},
@@ -3299,11 +3299,11 @@ pywin32-ctypes = [
{file = "pywin32_ctypes-0.2.0-py2.py3-none-any.whl", hash = "sha256:9dc2d991b3479cc2df15930958b674a48a227d5361d413827a4cfd0b5876fc98"},
]
pywinpty = [
{file = "pywinpty-2.0.2-cp310-none-win_amd64.whl", hash = "sha256:4b421379b407bf2f52a64a4c58f61deffe623b5add02d871acb290b771bb6227"},
{file = "pywinpty-2.0.2-cp37-none-win_amd64.whl", hash = "sha256:238b75fc456a6bc558761a89c9e6b3c8f2f54d79db03ae28997a68313c24b2ca"},
{file = "pywinpty-2.0.2-cp38-none-win_amd64.whl", hash = "sha256:344858a0b956fdc64a547d5e1980b0257b47f5433ed7cb89bf7b6268cb280c6c"},
{file = "pywinpty-2.0.2-cp39-none-win_amd64.whl", hash = "sha256:a4a066eaf2e30944d3028d946883ceb7883a499b53c4b89ca2d54bd7a4210550"},
{file = "pywinpty-2.0.2.tar.gz", hash = "sha256:20ec117183f79642eff555ce0dd1823f942618d65813fb6122d14b6e34b5d05a"},
{file = "pywinpty-2.0.5-cp310-none-win_amd64.whl", hash = "sha256:f86c76e2881c37e69678cbbf178109f8da1fa8584db24d58e1b9369b0276cfcb"},
{file = "pywinpty-2.0.5-cp37-none-win_amd64.whl", hash = "sha256:ff9b52f182650cfdf3db1b264a6fe0963eb9d996a7a1fa843ac406c1e32111f8"},
{file = "pywinpty-2.0.5-cp38-none-win_amd64.whl", hash = "sha256:651ee1467bd7eb6f64d44dbc954b7ab7d15ab6d8adacc4e13299692c67c5d5d2"},
{file = "pywinpty-2.0.5-cp39-none-win_amd64.whl", hash = "sha256:e59a508ae78374febada3e53b5bbc90b5ad07ae68cbfd72a2e965f9793ae04f3"},
{file = "pywinpty-2.0.5.tar.gz", hash = "sha256:e125d3f1804d8804952b13e33604ad2ca8b9b2cac92b27b521c005d1604794f8"},
]
pyyaml = [
{file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"},
@@ -3398,8 +3398,8 @@ qtpy = [
{file = "QtPy-2.0.1.tar.gz", hash = "sha256:adfd073ffbd2de81dc7aaa0b983499ef5c59c96adcfdcc9dea60d42ca885eb8f"},
]
readme-renderer = [
{file = "readme_renderer-32.0-py3-none-any.whl", hash = "sha256:a50a0f2123a4c1145ac6f420e1a348aafefcc9211c846e3d51df05fe3d865b7d"},
{file = "readme_renderer-32.0.tar.gz", hash = "sha256:b512beafa6798260c7d5af3e1b1f097e58bfcd9a575da7c4ddd5e037490a5b85"},
{file = "readme_renderer-33.0-py3-none-any.whl", hash = "sha256:f02cee0c4de9636b5a62b6be50c9742427ba1b956aad1d938bfb087d0d72ccdf"},
{file = "readme_renderer-33.0.tar.gz", hash = "sha256:e3b53bc84bd6af054e4cc1fe3567dc1ae19f554134221043a3f8c674e22209db"},
]
requests = [
{file = "requests-2.27.1-py2.py3-none-any.whl", hash = "sha256:f22fa1e554c9ddfd16e6e41ac79759e17be9e492b3587efa038054674760e72d"},
@@ -3496,8 +3496,8 @@ termcolor = [
{file = "termcolor-1.1.0.tar.gz", hash = "sha256:1d6d69ce66211143803fbc56652b41d73b4a400a2891d7bf7a1cdf4c02de613b"},
]
terminado = [
{file = "terminado-0.13.1-py3-none-any.whl", hash = "sha256:f446b522b50a7aa68b5def0a02893978fb48cb82298b0ebdae13003c6ee6f198"},
{file = "terminado-0.13.1.tar.gz", hash = "sha256:5b82b5c6e991f0705a76f961f43262a7fb1e55b093c16dca83f16384a7f39b7b"},
{file = "terminado-0.13.2-py3-none-any.whl", hash = "sha256:d61f112f3beb7271d953d3934f056af185f6be0750303581fa1c511379a8a5d0"},
{file = "terminado-0.13.2.tar.gz", hash = "sha256:e6147a7ea31d150f9df4a26cedde3dbb2e011be269f89ff0267ae4157f3ae426"},
]
testpath = [
{file = "testpath-0.6.0-py3-none-any.whl", hash = "sha256:8ada9f80a2ac6fb0391aa7cdb1a7d11cfa8429f693eda83f74dde570fe6fa639"},

View File

@@ -45,7 +45,7 @@ pygraphviz = { version = "^1.7", optional = true }
Pillow = "^9.0.0"
loguru = "^0.5.3"
setuptools = "*"
concrete-compiler = "^0.3.1"
concrete-compiler = "^0.4.0"
torch = "^1.10.2"
[tool.poetry.extras]

View File

@@ -176,6 +176,7 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
ClearTensor(Integer(32, True), shape=(2, 3)),
],
Integer(32, True),
(3, 3),
),
[numpy.arange(1, 7).reshape(3, 2), numpy.arange(1, 7).reshape(2, 3)],
numpy.array([[9, 12, 15], [19, 26, 33], [29, 40, 51]]),

View File

@@ -1681,6 +1681,61 @@ def test_compile_and_run_constant_dot_correctness(
(3,),
(0, 3),
),
pytest.param(
(5,),
(4, 5, 3),
(0, 5),
),
pytest.param(
(4, 5, 3),
(3,),
(0, 5),
),
pytest.param(
(5,),
(2, 4, 5, 3),
(0, 5),
),
pytest.param(
(2, 4, 5, 3),
(3,),
(0, 5),
),
pytest.param(
(5, 4, 3),
(3, 2),
(0, 5),
),
pytest.param(
(4, 3),
(5, 3, 2),
(0, 5),
),
pytest.param(
(2, 5, 4, 3),
(3, 2),
(0, 5),
),
pytest.param(
(5, 4, 3),
(1, 3, 2),
(0, 5),
),
pytest.param(
(1, 4, 3),
(5, 3, 2),
(0, 5),
),
pytest.param(
(5, 4, 3),
(2, 1, 3, 2),
(0, 5),
),
pytest.param(
(2, 1, 4, 3),
(5, 3, 2),
(0, 5),
),
],
)
def test_compile_and_run_matmul_correctness(