feat(mlir): conversion of dot node into MLIR

This commit is contained in:
youben11
2021-09-01 15:00:54 +01:00
committed by Ayoub Benaissa
parent f686ca535a
commit 31c1787af2
5 changed files with 84 additions and 19 deletions

View File

@@ -3,7 +3,7 @@ import pytest
from hdk.common.data_types.floats import Float
from hdk.common.data_types.integers import Integer
from hdk.common.mlir.converters import add, apply_lut, constant, mul, sub
from hdk.common.mlir.converters import add, apply_lut, constant, dot, mul, sub
from hdk.common.values import ClearScalar, EncryptedScalar
@@ -21,7 +21,7 @@ class MockNode:
self.outputs = outputs
@pytest.mark.parametrize("converter", [add, sub, mul])
@pytest.mark.parametrize("converter", [add, sub, mul, dot])
def test_failing_converter(converter):
"""Test failing converter"""
with pytest.raises(TypeError, match=r"Don't support .* between .* and .*"):

View File

@@ -2,6 +2,7 @@
# pylint: disable=no-name-in-module,no-member
import itertools
import numpy
import pytest
from mlir.ir import IntegerType, Location, RankedTensorType, UnrankedTensorType
from zamalang import compiler
@@ -66,6 +67,11 @@ def lut(x):
return table[x]
def dot(x, y):
"""Test dot"""
return numpy.dot(x, y)
def datagen(*args):
"""Generate data from ranges"""
for prod in itertools.product(*args):
@@ -178,6 +184,22 @@ def datagen(*args):
},
(range(0, 8),),
),
(
dot,
{
"x": EncryptedTensor(Integer(64, is_signed=False), shape=(4,)),
"y": ClearTensor(Integer(64, is_signed=False), shape=(4,)),
},
(range(0, 8), range(0, 8)),
),
(
dot,
{
"x": ClearTensor(Integer(64, is_signed=False), shape=(4,)),
"y": EncryptedTensor(Integer(64, is_signed=False), shape=(4,)),
},
(range(0, 8), range(0, 8)),
),
],
)
def test_mlir_converter(func, args_dict, args_ranges):