mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(tracing): implement and test tracing of basic operations for tensors
This commit is contained in:
@@ -304,6 +304,35 @@ def _get_fun(function: numpy.ufunc):
|
||||
# We are populating NPTracer.UFUNC_ROUTING dynamically
|
||||
NPTracer.UFUNC_ROUTING = {fun: _get_fun(fun) for fun in NPTracer.LIST_OF_SUPPORTED_UFUNC}
|
||||
|
||||
# We are adding initial support for `np.array(...)` +,-,* `BaseTracer`
|
||||
# (note that this is not the proper complete handling of these functions)
|
||||
|
||||
|
||||
def _on_numpy_add(lhs, rhs):
|
||||
if isinstance(lhs, BaseTracer):
|
||||
return lhs.__add__(rhs)
|
||||
|
||||
return rhs.__radd__(lhs)
|
||||
|
||||
|
||||
def _on_numpy_subtract(lhs, rhs):
|
||||
if isinstance(lhs, BaseTracer):
|
||||
return lhs.__sub__(rhs)
|
||||
|
||||
return rhs.__rsub__(lhs)
|
||||
|
||||
|
||||
def _on_numpy_multiply(lhs, rhs):
|
||||
if isinstance(lhs, BaseTracer):
|
||||
return lhs.__mul__(rhs)
|
||||
|
||||
return rhs.__rmul__(lhs)
|
||||
|
||||
|
||||
NPTracer.UFUNC_ROUTING[numpy.add] = _on_numpy_add
|
||||
NPTracer.UFUNC_ROUTING[numpy.subtract] = _on_numpy_subtract
|
||||
NPTracer.UFUNC_ROUTING[numpy.multiply] = _on_numpy_multiply
|
||||
|
||||
|
||||
def trace_numpy_function(
|
||||
function_to_trace: Callable, function_parameters: Dict[str, BaseValue]
|
||||
|
||||
@@ -8,6 +8,7 @@ import pytest
|
||||
|
||||
from concrete.common.data_types.floats import Float
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.debugging import get_printable_graph
|
||||
from concrete.common.representation import intermediate as ir
|
||||
from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy import tracing
|
||||
@@ -165,30 +166,82 @@ def test_numpy_tracing_binary_op(operation, x, y, test_helpers):
|
||||
assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tensor_constructor",
|
||||
[
|
||||
EncryptedTensor,
|
||||
ClearTensor,
|
||||
],
|
||||
)
|
||||
def test_numpy_tracing_tensor_constant(tensor_constructor):
|
||||
"Test numpy tracing tensor constant"
|
||||
def test_numpy_tracing_tensors():
|
||||
"Test numpy tracing tensors"
|
||||
|
||||
def simple_add_tensor(x):
|
||||
return x + numpy.array([[1, 2], [3, 4]], dtype=numpy.int32)
|
||||
def all_operations(x):
|
||||
intermediate = x + numpy.array([[1, 2], [3, 4]])
|
||||
intermediate = numpy.array([[5, 6], [7, 8]]) + intermediate
|
||||
|
||||
intermediate = numpy.array([[100, 200], [300, 400]]) - intermediate
|
||||
intermediate = intermediate - numpy.array([[10, 20], [30, 40]])
|
||||
|
||||
intermediate = intermediate * numpy.array([[1, 2], [2, 1]])
|
||||
intermediate = numpy.array([[2, 1], [1, 2]]) * intermediate
|
||||
|
||||
return intermediate
|
||||
|
||||
op_graph = tracing.trace_numpy_function(
|
||||
simple_add_tensor, {"x": tensor_constructor(Integer(32, True), shape=(2, 2))}
|
||||
all_operations, {"x": EncryptedTensor(Integer(32, True), shape=(2, 2))}
|
||||
)
|
||||
|
||||
constant_inputs = [node for node in op_graph.graph.nodes() if isinstance(node, ir.Constant)]
|
||||
assert len(constant_inputs) == 1
|
||||
expected = """
|
||||
%0 = Constant([[2 1] [1 2]]) # ClearTensor<Integer<unsigned, 2 bits>, shape=(2, 2)>
|
||||
%1 = Constant([[1 2] [2 1]]) # ClearTensor<Integer<unsigned, 2 bits>, shape=(2, 2)>
|
||||
%2 = Constant([[10 20] [30 40]]) # ClearTensor<Integer<unsigned, 6 bits>, shape=(2, 2)>
|
||||
%3 = Constant([[100 200] [300 400]]) # ClearTensor<Integer<unsigned, 9 bits>, shape=(2, 2)>
|
||||
%4 = Constant([[5 6] [7 8]]) # ClearTensor<Integer<unsigned, 4 bits>, shape=(2, 2)>
|
||||
%5 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%6 = Constant([[1 2] [3 4]]) # ClearTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
|
||||
%7 = Add(5, 6) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%8 = Add(7, 4) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%9 = Sub(3, 8) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%10 = Sub(9, 2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%11 = Mul(10, 1) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%12 = Mul(11, 0) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
return(%12)
|
||||
""".lstrip()
|
||||
|
||||
constant_input_data = constant_inputs[0].constant_data
|
||||
assert get_printable_graph(op_graph, show_data_types=True) == expected
|
||||
|
||||
assert (constant_input_data == numpy.array([[1, 2], [3, 4]], dtype=numpy.int32)).all()
|
||||
assert op_graph.get_ordered_outputs()[0].outputs[0].shape == constant_input_data.shape
|
||||
|
||||
def test_numpy_explicit_tracing_tensors():
|
||||
"Test numpy tracing tensors using explicit operations"
|
||||
|
||||
def all_explicit_operations(x):
|
||||
intermediate = numpy.add(x, numpy.array([[1, 2], [3, 4]]))
|
||||
intermediate = numpy.add(numpy.array([[5, 6], [7, 8]]), intermediate)
|
||||
|
||||
intermediate = numpy.subtract(numpy.array([[100, 200], [300, 400]]), intermediate)
|
||||
intermediate = numpy.subtract(intermediate, numpy.array([[10, 20], [30, 40]]))
|
||||
|
||||
intermediate = numpy.multiply(intermediate, numpy.array([[1, 2], [2, 1]]))
|
||||
intermediate = numpy.multiply(numpy.array([[2, 1], [1, 2]]), intermediate)
|
||||
|
||||
return intermediate
|
||||
|
||||
op_graph = tracing.trace_numpy_function(
|
||||
all_explicit_operations, {"x": EncryptedTensor(Integer(32, True), shape=(2, 2))}
|
||||
)
|
||||
|
||||
expected = """
|
||||
%0 = Constant([[2 1] [1 2]]) # ClearTensor<Integer<unsigned, 2 bits>, shape=(2, 2)>
|
||||
%1 = Constant([[1 2] [2 1]]) # ClearTensor<Integer<unsigned, 2 bits>, shape=(2, 2)>
|
||||
%2 = Constant([[10 20] [30 40]]) # ClearTensor<Integer<unsigned, 6 bits>, shape=(2, 2)>
|
||||
%3 = Constant([[100 200] [300 400]]) # ClearTensor<Integer<unsigned, 9 bits>, shape=(2, 2)>
|
||||
%4 = Constant([[5 6] [7 8]]) # ClearTensor<Integer<unsigned, 4 bits>, shape=(2, 2)>
|
||||
%5 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%6 = Constant([[1 2] [3 4]]) # ClearTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
|
||||
%7 = Add(5, 6) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%8 = Add(7, 4) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%9 = Sub(3, 8) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%10 = Sub(9, 2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%11 = Mul(10, 1) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%12 = Mul(11, 0) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
return(%12)
|
||||
""".lstrip()
|
||||
|
||||
assert get_printable_graph(op_graph, show_data_types=True) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Reference in New Issue
Block a user