From b71cbc8ecbadee86c3f2ff3979e41e948263ade2 Mon Sep 17 00:00:00 2001 From: Umut Date: Mon, 28 Feb 2022 15:06:04 +0300 Subject: [PATCH] feat: mimic the exact numpy behavior for matmul --- concrete/common/mlir/node_converter.py | 8 +- .../common/representation/intermediate.py | 20 +---- concrete/numpy/tracing.py | 28 +------ deps_licenses/licenses_linux_user.txt | 4 +- poetry.lock | 76 +++++++++---------- pyproject.toml | 2 +- .../representation/test_intermediate.py | 1 + tests/numpy/test_compile.py | 55 ++++++++++++++ 8 files changed, 108 insertions(+), 86 deletions(-) diff --git a/concrete/common/mlir/node_converter.py b/concrete/common/mlir/node_converter.py index 8266a1c79..e6f5462e0 100644 --- a/concrete/common/mlir/node_converter.py +++ b/concrete/common/mlir/node_converter.py @@ -628,7 +628,13 @@ class IntermediateNodeConverter: resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0]) preds = self.preds - if self.node.inputs[0].is_clear: + assert isinstance(self.node.outputs[0], TensorValue) + if self.node.outputs[0].shape == (): + if self.node.inputs[0].is_clear: + preds = preds[::-1] + result = fhelinalg.Dot(resulting_type, *preds).result + + elif self.node.inputs[0].is_clear: result = fhelinalg.MatMulIntEintOp(resulting_type, *preds).result else: result = fhelinalg.MatMulEintIntOp(resulting_type, *preds).result diff --git a/concrete/common/representation/intermediate.py b/concrete/common/representation/intermediate.py index 5faaed6b8..53048dff2 100644 --- a/concrete/common/representation/intermediate.py +++ b/concrete/common/representation/intermediate.py @@ -630,30 +630,14 @@ class MatMul(IntermediateNode): self, inputs: Iterable[BaseValue], output_dtype: BaseDataType, + output_shape: Tuple[int, ...], ) -> None: super().__init__(inputs) assert_true(len(self.inputs) == 2) - assert_true( - all( - isinstance(input_value, TensorValue) and input_value.ndim == 2 - for input_value in self.inputs - ), - f"MatMul only supports two matrices ({TensorValue.__name__} with ndim == 2)", - ) - - lhs = cast(TensorValue, self.inputs[0]) - rhs = cast(TensorValue, self.inputs[1]) - - assert_true( - lhs.shape[1] == rhs.shape[0], - f"MatMul between matrices of shapes {lhs.shape} and {rhs.shape} is not supported", - ) - - output_shape = (lhs.shape[0], rhs.shape[1]) output_value = ( EncryptedTensor(dtype=output_dtype, shape=output_shape) - if (lhs.is_encrypted or rhs.is_encrypted) + if (self.inputs[0].is_encrypted or self.inputs[1].is_encrypted) else ClearTensor(dtype=output_dtype, shape=output_shape) ) diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index 09a1858f0..17356537f 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -773,36 +773,12 @@ def _on_numpy_matmul(lhs: NPTracer, rhs: NPTracer): assert_true(len(common_output_dtypes_and_shapes) == 1) output_shape = common_output_dtypes_and_shapes[0][1] - - # TODO: https://github.com/zama-ai/concrete-numpy-internal/issues/1174 - # remove all the reshape logic once matmul supports more combinations of arguments - if isinstance(lhs_output := lhs.output, TensorValue) and isinstance( - rhs_output := rhs.output, TensorValue - ): - # Manage non 2D cases - if lhs_output.ndim == 1 and rhs_output.ndim == 1: - lhs = lhs.reshape((1, lhs_output.shape[0])) - rhs = rhs.reshape((rhs_output.shape[0], 1)) - elif lhs_output.ndim == 1: - # lhs is a vector, reshape to be 2D and give proper result - lhs = lhs.reshape((1, lhs_output.shape[0])) - elif rhs_output.ndim == 1: - # rhs is a vector, reshape to be 2D and give proper result - rhs = rhs.reshape((rhs_output.shape[0], 1)) - traced_computation = MatMul( [lhs.output, rhs.output], common_output_dtypes_and_shapes[0][0], + output_shape, ) - - matmul_tracer = NPTracer([lhs, rhs], traced_computation, output_idx=0) - - # Return the reshaped result if vector reshaping for 2D matmul happened - if matmul_tracer.shape != output_shape: - if output_shape == (): - return matmul_tracer[0, 0] - return matmul_tracer.reshape(output_shape) - return matmul_tracer + return NPTracer([lhs, rhs], traced_computation, output_idx=0) NPTracer.UFUNC_ROUTING[numpy.add] = _on_numpy_add diff --git a/deps_licenses/licenses_linux_user.txt b/deps_licenses/licenses_linux_user.txt index adc643a38..83c8b0804 100644 --- a/deps_licenses/licenses_linux_user.txt +++ b/deps_licenses/licenses_linux_user.txt @@ -1,13 +1,13 @@ Name Version License Pillow 9.0.1 Historical Permission Notice and Disclaimer (HPND) PyYAML 6.0 MIT License - concrete-compiler 0.3.1 BSD-3 + concrete-compiler 0.4.0 BSD-3 cycler 0.11.0 BSD License fonttools 4.29.1 MIT License kiwisolver 1.3.2 BSD License loguru 0.5.3 MIT License matplotlib 3.5.1 Python Software Foundation License - networkx 2.7 BSD License + networkx 2.7.1 BSD License numpy 1.22.2 BSD License packaging 21.3 Apache Software License; BSD License pygraphviz 1.9 BSD License diff --git a/poetry.lock b/poetry.lock index d9620f94e..e649e31b8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -243,7 +243,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [[package]] name = "concrete-compiler" -version = "0.3.1" +version = "0.4.0" description = "Concrete Compiler" category = "main" optional = false @@ -553,7 +553,7 @@ test = ["pytest (!=5.3.4)", "pytest-cov", "flaky", "ipyparallel"] [[package]] name = "ipython" -version = "8.1.0" +version = "8.1.1" description = "IPython: Productive Interactive Computing" category = "dev" optional = false @@ -725,7 +725,7 @@ test = ["codecov", "coverage", "ipykernel", "ipython", "mock", "mypy", "pre-comm [[package]] name = "jupyter-console" -version = "6.4.0" +version = "6.4.2" description = "Jupyter terminal console" category = "dev" optional = false @@ -975,7 +975,7 @@ testing = ["beautifulsoup4", "coverage", "docutils (>=0.17.0,<0.18.0)", "pytest [[package]] name = "nbclient" -version = "0.5.11" +version = "0.5.12" description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." category = "dev" optional = false @@ -985,7 +985,7 @@ python-versions = ">=3.7.0" jupyter-client = ">=6.1.5" nbformat = ">=5.0" nest-asyncio = "*" -traitlets = ">=4.2" +traitlets = ">=5.0.0" [package.extras] sphinx = ["Sphinx (>=1.7)", "sphinx-book-theme", "mock", "moto", "myst-parser"] @@ -1081,7 +1081,7 @@ python-versions = ">=3.5" [[package]] name = "networkx" -version = "2.7" +version = "2.7.1" description = "Python package for creating and manipulating graphs and networks" category = "main" optional = false @@ -1372,7 +1372,7 @@ python-versions = "*" [[package]] name = "py-progress-tracker" -version = "0.4.2" +version = "0.4.6" description = "A simple benchmarking library" category = "dev" optional = false @@ -1680,11 +1680,11 @@ python-versions = "*" [[package]] name = "pywinpty" -version = "2.0.2" +version = "2.0.5" description = "Pseudo terminal support for Windows from Python." category = "dev" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" [[package]] name = "pyyaml" @@ -1744,7 +1744,7 @@ test = ["pytest (>=6.0.0)", "pytest-cov (>=3.0.0)", "pytest-qt"] [[package]] name = "readme-renderer" -version = "32.0" +version = "33.0" description = "readme_renderer is a library for rendering \"readme\" descriptions for Warehouse" category = "dev" optional = false @@ -1756,7 +1756,7 @@ docutils = ">=0.13.1" Pygments = ">=2.5.1" [package.extras] -md = ["cmarkgfm (>=0.5.0,<0.7.0)"] +md = ["cmarkgfm (>=0.8.0)"] [[package]] name = "requests" @@ -2062,7 +2062,7 @@ python-versions = "*" [[package]] name = "terminado" -version = "0.13.1" +version = "0.13.2" description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library." category = "dev" optional = false @@ -2279,7 +2279,7 @@ full = ["pygraphviz"] [metadata] lock-version = "1.1" python-versions = ">=3.8,<3.10" -content-hash = "89f3c912cef146d06a3da96e36347367e1703fa5670f3153637e1822c5e81897" +content-hash = "126b3e771561008efd62e7302c2bc0db6eb55879478553cfd7b598e42a5a0559" [metadata.files] alabaster = [ @@ -2430,12 +2430,12 @@ colorama = [ {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, ] concrete-compiler = [ - {file = "concrete_compiler-0.3.1-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a6f3d27e3246af55d3f8116c12bd63747691939ca6b0c97170d6d7eda5bd3bb7"}, - {file = "concrete_compiler-0.3.1-cp310-cp310-manylinux_2_24_x86_64.whl", hash = "sha256:69e77d45a5df39758bbd38c3fa154d479ad7855afdc06bb7f93c75424d00eae8"}, - {file = "concrete_compiler-0.3.1-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:fc2d87aebf0c6772dc9e443194b8c3a363614c6fb1042a5e86a9f5f77a76e360"}, - {file = "concrete_compiler-0.3.1-cp38-cp38-manylinux_2_24_x86_64.whl", hash = "sha256:9adc23818a2d64d24e0ab94fd40938b6aaf5ae32c8ad6b562158bd55914ac319"}, - {file = "concrete_compiler-0.3.1-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:90ff4fc19dea3f28d7f177ab53979be533443424534f7ee7cce2f0622b82eb58"}, - {file = "concrete_compiler-0.3.1-cp39-cp39-manylinux_2_24_x86_64.whl", hash = "sha256:6d21fe84c739e482e3d1578d1f567226047736e77346d57acc94de5a2b108251"}, + {file = "concrete_compiler-0.4.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:382cea30b9e7805dbb6620d1d3e31f66a3346c7e8ff4cae30c507bb843c0e5b1"}, + {file = "concrete_compiler-0.4.0-cp310-cp310-manylinux_2_24_x86_64.whl", hash = "sha256:1681e7273d39c2ef8e970db3fc0039504b4cd1fdef788eb677fa2ada5f0a9487"}, + {file = "concrete_compiler-0.4.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:8fec775022d91886aae0f4e19061b1423f2c9d092bb444422417ce342c5981a3"}, + {file = "concrete_compiler-0.4.0-cp38-cp38-manylinux_2_24_x86_64.whl", hash = "sha256:3b3f31145688e58a95a6d1e9604b42e80f03d827ba11014f7611d394bcd60e78"}, + {file = "concrete_compiler-0.4.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:0ed0cd4355ef60710601ba7d7479e5bd3546764faeafc612e3e42f49d9c290bd"}, + {file = "concrete_compiler-0.4.0-cp39-cp39-manylinux_2_24_x86_64.whl", hash = "sha256:e25c760d4f08a39c0f0af895497e8da457e9db0df9cf8752e3735150fc8d3227"}, ] coverage = [ {file = "coverage-6.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9b27d894748475fa858f9597c0ee1d4829f44683f3813633aaf94b19cb5453cf"}, @@ -2614,8 +2614,8 @@ ipykernel = [ {file = "ipykernel-6.9.1.tar.gz", hash = "sha256:f95070a2dfd3147f8ab19f18ee46733310813758593745e07ec18fb08b409f1d"}, ] ipython = [ - {file = "ipython-8.1.0-py3-none-any.whl", hash = "sha256:7bfeb6f298b2d7f3859c4f3e134082015cf34de90f89f5020e107a5a762ef6db"}, - {file = "ipython-8.1.0.tar.gz", hash = "sha256:42c23e90b2deaae631266885de1656a517a1673d7e1db57e8eb3a4bb6cd5ce1b"}, + {file = "ipython-8.1.1-py3-none-any.whl", hash = "sha256:6f56bfaeaa3247aa3b9cd3b8cbab3a9c0abf7428392f97b21902d12b2f42a381"}, + {file = "ipython-8.1.1.tar.gz", hash = "sha256:8138762243c9b3a3ffcf70b37151a2a35c23d3a29f9743878c33624f4207be3d"}, ] ipython-genutils = [ {file = "ipython_genutils-0.2.0-py2.py3-none-any.whl", hash = "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8"}, @@ -2655,8 +2655,8 @@ jupyter-client = [ {file = "jupyter_client-7.1.2.tar.gz", hash = "sha256:4ea61033726c8e579edb55626d8ee2e6bf0a83158ddf3751b8dd46b2c5cd1e96"}, ] jupyter-console = [ - {file = "jupyter_console-6.4.0-py3-none-any.whl", hash = "sha256:7799c4ea951e0e96ba8260575423cb323ea5a03fcf5503560fa3e15748869e27"}, - {file = "jupyter_console-6.4.0.tar.gz", hash = "sha256:242248e1685039cd8bff2c2ecb7ce6c1546eb50ee3b08519729e6e881aec19c7"}, + {file = "jupyter_console-6.4.2-py3-none-any.whl", hash = "sha256:1d52cf1a80f0c7accaa2b8c68ca3e6fa2311ec33ac9651d6cb6b9168cca1dad9"}, + {file = "jupyter_console-6.4.2.tar.gz", hash = "sha256:fce5bccac926c690924168ad46cae33a7d78d643a7b60af0f260af25d38ecf26"}, ] jupyter-core = [ {file = "jupyter_core-4.9.2-py3-none-any.whl", hash = "sha256:f875e4d27e202590311d468fa55f90c575f201490bd0c18acabe4e318db4a46d"}, @@ -2933,8 +2933,8 @@ myst-parser = [ {file = "myst_parser-0.15.2-py3-none-any.whl", hash = "sha256:40124b6f27a4c42ac7f06b385e23a9dcd03d84801e9c7130b59b3729a554b1f9"}, ] nbclient = [ - {file = "nbclient-0.5.11-py3-none-any.whl", hash = "sha256:03e857bea3012377289daa1e1c1651f4fc0295bcd109ccd36a337efcdbebaed7"}, - {file = "nbclient-0.5.11.tar.gz", hash = "sha256:751516992f34b58172bad54eef1e4bf7e4f4460d58e255ca1a4e5c9649476007"}, + {file = "nbclient-0.5.12-py3-none-any.whl", hash = "sha256:ff2d908024aaabb8864e5392c3517c76e17994b1f9330dda9b5284da9275c499"}, + {file = "nbclient-0.5.12.tar.gz", hash = "sha256:0dd7ee6db59753563035f606421c3b558bd8c28b116d7e3ab8a0b4026cb44e38"}, ] nbconvert = [ {file = "nbconvert-6.4.2-py3-none-any.whl", hash = "sha256:7b006ae9979af56200e7fa3db39d9d12c99e811e8843b05dbe518e5b754bcb2e"}, @@ -2957,8 +2957,8 @@ nest-asyncio = [ {file = "nest_asyncio-1.5.4.tar.gz", hash = "sha256:f969f6013a16fadb4adcf09d11a68a4f617c6049d7af7ac2c676110169a63abd"}, ] networkx = [ - {file = "networkx-2.7-py3-none-any.whl", hash = "sha256:836544e160f1b7ebf720c01667b7d2f5724c0424374aef2b665ed64dd14b0c7d"}, - {file = "networkx-2.7.tar.gz", hash = "sha256:effb7d9cd5c36e1e0d33f42a3aef5badde5030535826a367d5cf608a170af515"}, + {file = "networkx-2.7.1-py3-none-any.whl", hash = "sha256:011e85d277c89681e8fa661cf5ff0743443445049b0b68789ad55ef09340c6e0"}, + {file = "networkx-2.7.1.tar.gz", hash = "sha256:d1194ba753e5eed07cdecd1d23c5cd7a3c772099bd8dbd2fea366788cf4de7ba"}, ] notebook = [ {file = "notebook-6.4.8-py3-none-any.whl", hash = "sha256:3e702fcc54b8ae597533c3864793b7a1e971dec9e112f67235828d8a798fd654"}, @@ -3138,8 +3138,8 @@ py-cpuinfo = [ {file = "py-cpuinfo-8.0.0.tar.gz", hash = "sha256:5f269be0e08e33fd959de96b34cd4aeeeacac014dd8305f70eb28d06de2345c5"}, ] py-progress-tracker = [ - {file = "py-progress-tracker-0.4.2.tar.gz", hash = "sha256:92cbef419d923fe75ae3561312afbbfcd53082e720fb5bb4abf90c1299e7770f"}, - {file = "py_progress_tracker-0.4.2-py3-none-any.whl", hash = "sha256:b8e58f6b6e41cfda777c8b7b4da571e278cea4539f2c22a361703ad20bcabf84"}, + {file = "py-progress-tracker-0.4.6.tar.gz", hash = "sha256:32a20b035818b9c8af70b2c7ff1edd037396131e6c2c0c3dbf5ba720a841994c"}, + {file = "py_progress_tracker-0.4.6-py3-none-any.whl", hash = "sha256:fb1fed8a4ed996be32af47802b046df8b4c42b5cce80c7d2e44ff3c15a056126"}, ] pycodestyle = [ {file = "pycodestyle-2.8.0-py2.py3-none-any.whl", hash = "sha256:720f8b39dde8b293825e7ff02c475f3077124006db4f440dcbc9a20b76548a20"}, @@ -3299,11 +3299,11 @@ pywin32-ctypes = [ {file = "pywin32_ctypes-0.2.0-py2.py3-none-any.whl", hash = "sha256:9dc2d991b3479cc2df15930958b674a48a227d5361d413827a4cfd0b5876fc98"}, ] pywinpty = [ - {file = "pywinpty-2.0.2-cp310-none-win_amd64.whl", hash = "sha256:4b421379b407bf2f52a64a4c58f61deffe623b5add02d871acb290b771bb6227"}, - {file = "pywinpty-2.0.2-cp37-none-win_amd64.whl", hash = "sha256:238b75fc456a6bc558761a89c9e6b3c8f2f54d79db03ae28997a68313c24b2ca"}, - {file = "pywinpty-2.0.2-cp38-none-win_amd64.whl", hash = "sha256:344858a0b956fdc64a547d5e1980b0257b47f5433ed7cb89bf7b6268cb280c6c"}, - {file = "pywinpty-2.0.2-cp39-none-win_amd64.whl", hash = "sha256:a4a066eaf2e30944d3028d946883ceb7883a499b53c4b89ca2d54bd7a4210550"}, - {file = "pywinpty-2.0.2.tar.gz", hash = "sha256:20ec117183f79642eff555ce0dd1823f942618d65813fb6122d14b6e34b5d05a"}, + {file = "pywinpty-2.0.5-cp310-none-win_amd64.whl", hash = "sha256:f86c76e2881c37e69678cbbf178109f8da1fa8584db24d58e1b9369b0276cfcb"}, + {file = "pywinpty-2.0.5-cp37-none-win_amd64.whl", hash = "sha256:ff9b52f182650cfdf3db1b264a6fe0963eb9d996a7a1fa843ac406c1e32111f8"}, + {file = "pywinpty-2.0.5-cp38-none-win_amd64.whl", hash = "sha256:651ee1467bd7eb6f64d44dbc954b7ab7d15ab6d8adacc4e13299692c67c5d5d2"}, + {file = "pywinpty-2.0.5-cp39-none-win_amd64.whl", hash = "sha256:e59a508ae78374febada3e53b5bbc90b5ad07ae68cbfd72a2e965f9793ae04f3"}, + {file = "pywinpty-2.0.5.tar.gz", hash = "sha256:e125d3f1804d8804952b13e33604ad2ca8b9b2cac92b27b521c005d1604794f8"}, ] pyyaml = [ {file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"}, @@ -3398,8 +3398,8 @@ qtpy = [ {file = "QtPy-2.0.1.tar.gz", hash = "sha256:adfd073ffbd2de81dc7aaa0b983499ef5c59c96adcfdcc9dea60d42ca885eb8f"}, ] readme-renderer = [ - {file = "readme_renderer-32.0-py3-none-any.whl", hash = "sha256:a50a0f2123a4c1145ac6f420e1a348aafefcc9211c846e3d51df05fe3d865b7d"}, - {file = "readme_renderer-32.0.tar.gz", hash = "sha256:b512beafa6798260c7d5af3e1b1f097e58bfcd9a575da7c4ddd5e037490a5b85"}, + {file = "readme_renderer-33.0-py3-none-any.whl", hash = "sha256:f02cee0c4de9636b5a62b6be50c9742427ba1b956aad1d938bfb087d0d72ccdf"}, + {file = "readme_renderer-33.0.tar.gz", hash = "sha256:e3b53bc84bd6af054e4cc1fe3567dc1ae19f554134221043a3f8c674e22209db"}, ] requests = [ {file = "requests-2.27.1-py2.py3-none-any.whl", hash = "sha256:f22fa1e554c9ddfd16e6e41ac79759e17be9e492b3587efa038054674760e72d"}, @@ -3496,8 +3496,8 @@ termcolor = [ {file = "termcolor-1.1.0.tar.gz", hash = "sha256:1d6d69ce66211143803fbc56652b41d73b4a400a2891d7bf7a1cdf4c02de613b"}, ] terminado = [ - {file = "terminado-0.13.1-py3-none-any.whl", hash = "sha256:f446b522b50a7aa68b5def0a02893978fb48cb82298b0ebdae13003c6ee6f198"}, - {file = "terminado-0.13.1.tar.gz", hash = "sha256:5b82b5c6e991f0705a76f961f43262a7fb1e55b093c16dca83f16384a7f39b7b"}, + {file = "terminado-0.13.2-py3-none-any.whl", hash = "sha256:d61f112f3beb7271d953d3934f056af185f6be0750303581fa1c511379a8a5d0"}, + {file = "terminado-0.13.2.tar.gz", hash = "sha256:e6147a7ea31d150f9df4a26cedde3dbb2e011be269f89ff0267ae4157f3ae426"}, ] testpath = [ {file = "testpath-0.6.0-py3-none-any.whl", hash = "sha256:8ada9f80a2ac6fb0391aa7cdb1a7d11cfa8429f693eda83f74dde570fe6fa639"}, diff --git a/pyproject.toml b/pyproject.toml index 6a9beb6db..995c2ec75 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ pygraphviz = { version = "^1.7", optional = true } Pillow = "^9.0.0" loguru = "^0.5.3" setuptools = "*" -concrete-compiler = "^0.3.1" +concrete-compiler = "^0.4.0" torch = "^1.10.2" [tool.poetry.extras] diff --git a/tests/common/representation/test_intermediate.py b/tests/common/representation/test_intermediate.py index 8a8a9e755..ca5c073a1 100644 --- a/tests/common/representation/test_intermediate.py +++ b/tests/common/representation/test_intermediate.py @@ -176,6 +176,7 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En ClearTensor(Integer(32, True), shape=(2, 3)), ], Integer(32, True), + (3, 3), ), [numpy.arange(1, 7).reshape(3, 2), numpy.arange(1, 7).reshape(2, 3)], numpy.array([[9, 12, 15], [19, 26, 33], [29, 40, 51]]), diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index ab50e6a62..0bc93878d 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -1681,6 +1681,61 @@ def test_compile_and_run_constant_dot_correctness( (3,), (0, 3), ), + pytest.param( + (5,), + (4, 5, 3), + (0, 5), + ), + pytest.param( + (4, 5, 3), + (3,), + (0, 5), + ), + pytest.param( + (5,), + (2, 4, 5, 3), + (0, 5), + ), + pytest.param( + (2, 4, 5, 3), + (3,), + (0, 5), + ), + pytest.param( + (5, 4, 3), + (3, 2), + (0, 5), + ), + pytest.param( + (4, 3), + (5, 3, 2), + (0, 5), + ), + pytest.param( + (2, 5, 4, 3), + (3, 2), + (0, 5), + ), + pytest.param( + (5, 4, 3), + (1, 3, 2), + (0, 5), + ), + pytest.param( + (1, 4, 3), + (5, 3, 2), + (0, 5), + ), + pytest.param( + (5, 4, 3), + (2, 1, 3, 2), + (0, 5), + ), + pytest.param( + (2, 1, 4, 3), + (5, 3, 2), + (0, 5), + ), ], ) def test_compile_and_run_matmul_correctness(