mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(mlir): conversion of dot node into MLIR
This commit is contained in:
@@ -3,7 +3,7 @@ import pytest
|
||||
|
||||
from hdk.common.data_types.floats import Float
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.mlir.converters import add, apply_lut, constant, mul, sub
|
||||
from hdk.common.mlir.converters import add, apply_lut, constant, dot, mul, sub
|
||||
from hdk.common.values import ClearScalar, EncryptedScalar
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ class MockNode:
|
||||
self.outputs = outputs
|
||||
|
||||
|
||||
@pytest.mark.parametrize("converter", [add, sub, mul])
|
||||
@pytest.mark.parametrize("converter", [add, sub, mul, dot])
|
||||
def test_failing_converter(converter):
|
||||
"""Test failing converter"""
|
||||
with pytest.raises(TypeError, match=r"Don't support .* between .* and .*"):
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# pylint: disable=no-name-in-module,no-member
|
||||
import itertools
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
from mlir.ir import IntegerType, Location, RankedTensorType, UnrankedTensorType
|
||||
from zamalang import compiler
|
||||
@@ -66,6 +67,11 @@ def lut(x):
|
||||
return table[x]
|
||||
|
||||
|
||||
def dot(x, y):
|
||||
"""Test dot"""
|
||||
return numpy.dot(x, y)
|
||||
|
||||
|
||||
def datagen(*args):
|
||||
"""Generate data from ranges"""
|
||||
for prod in itertools.product(*args):
|
||||
@@ -178,6 +184,22 @@ def datagen(*args):
|
||||
},
|
||||
(range(0, 8),),
|
||||
),
|
||||
(
|
||||
dot,
|
||||
{
|
||||
"x": EncryptedTensor(Integer(64, is_signed=False), shape=(4,)),
|
||||
"y": ClearTensor(Integer(64, is_signed=False), shape=(4,)),
|
||||
},
|
||||
(range(0, 8), range(0, 8)),
|
||||
),
|
||||
(
|
||||
dot,
|
||||
{
|
||||
"x": ClearTensor(Integer(64, is_signed=False), shape=(4,)),
|
||||
"y": EncryptedTensor(Integer(64, is_signed=False), shape=(4,)),
|
||||
},
|
||||
(range(0, 8), range(0, 8)),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_mlir_converter(func, args_dict, args_ranges):
|
||||
|
||||
Reference in New Issue
Block a user