mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-10 20:55:02 -05:00
feat(tracing): update hnumpy to manage tensor types
- binary ops and constants support
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
|
||||
|
||||
import numpy
|
||||
from zamalang import CompilerEngine
|
||||
|
||||
from ..common.bounds_measurement.dataset_eval import eval_op_graph_bounds_on_dataset
|
||||
@@ -20,6 +21,32 @@ from ..hnumpy.tracing import trace_numpy_function
|
||||
from .np_dtypes_helpers import get_base_data_type_for_numpy_or_python_constant_data
|
||||
|
||||
|
||||
def numpy_max_func(lhs: Any, rhs: Any) -> Any:
|
||||
"""Compute the maximum value between two values which can be numpy classes (e.g. ndarray).
|
||||
|
||||
Args:
|
||||
lhs (Any): lhs value to compute max from.
|
||||
rhs (Any): rhs value to compute max from.
|
||||
|
||||
Returns:
|
||||
Any: maximum scalar value between lhs and rhs.
|
||||
"""
|
||||
return numpy.maximum(lhs, rhs).max()
|
||||
|
||||
|
||||
def numpy_min_func(lhs: Any, rhs: Any) -> Any:
|
||||
"""Compute the minimum value between two values which can be numpy classes (e.g. ndarray).
|
||||
|
||||
Args:
|
||||
lhs (Any): lhs value to compute min from.
|
||||
rhs (Any): rhs value to compute min from.
|
||||
|
||||
Returns:
|
||||
Any: minimum scalar value between lhs and rhs.
|
||||
"""
|
||||
return numpy.minimum(lhs, rhs).min()
|
||||
|
||||
|
||||
def compile_numpy_function_into_op_graph(
|
||||
function_to_trace: Callable,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
@@ -72,7 +99,12 @@ def compile_numpy_function_into_op_graph(
|
||||
)
|
||||
|
||||
# Find bounds with the dataset
|
||||
node_bounds = eval_op_graph_bounds_on_dataset(op_graph, dataset)
|
||||
node_bounds = eval_op_graph_bounds_on_dataset(
|
||||
op_graph,
|
||||
dataset,
|
||||
min_func=numpy_min_func,
|
||||
max_func=numpy_max_func,
|
||||
)
|
||||
|
||||
# Update the graph accordingly: after that, we have the compilable graph
|
||||
op_graph.update_values_with_bounds(
|
||||
|
||||
@@ -15,7 +15,7 @@ from ..common.data_types.dtypes_helpers import (
|
||||
)
|
||||
from ..common.data_types.floats import Float
|
||||
from ..common.data_types.integers import Integer
|
||||
from ..common.values import BaseValue, ScalarValue
|
||||
from ..common.values import BaseValue, ScalarValue, TensorValue
|
||||
|
||||
NUMPY_TO_HDK_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = {
|
||||
numpy.dtype(numpy.int32): Integer(32, is_signed=True),
|
||||
@@ -110,11 +110,13 @@ def get_base_data_type_for_numpy_or_python_constant_data(constant_data: Any) ->
|
||||
"""
|
||||
base_dtype: BaseDataType
|
||||
assert isinstance(
|
||||
constant_data, (int, float, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
|
||||
constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
|
||||
), f"Unsupported constant data of type {type(constant_data)}"
|
||||
if isinstance(constant_data, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES):
|
||||
base_dtype = convert_numpy_dtype_to_base_data_type(constant_data)
|
||||
if isinstance(constant_data, (numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)):
|
||||
# numpy
|
||||
base_dtype = convert_numpy_dtype_to_base_data_type(constant_data.dtype)
|
||||
else:
|
||||
# python
|
||||
base_dtype = get_base_data_type_for_python_constant_data(constant_data)
|
||||
return base_dtype
|
||||
|
||||
@@ -139,11 +141,13 @@ def get_base_value_for_numpy_or_python_constant_data(
|
||||
"""
|
||||
constant_data_value: Callable[..., BaseValue]
|
||||
assert isinstance(
|
||||
constant_data, (int, float, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
|
||||
constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
|
||||
), f"Unsupported constant data of type {type(constant_data)}"
|
||||
|
||||
base_dtype = get_base_data_type_for_numpy_or_python_constant_data(constant_data)
|
||||
if isinstance(constant_data, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES):
|
||||
if isinstance(constant_data, numpy.ndarray):
|
||||
constant_data_value = partial(TensorValue, data_type=base_dtype, shape=constant_data.shape)
|
||||
elif isinstance(constant_data, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES):
|
||||
constant_data_value = partial(ScalarValue, data_type=base_dtype)
|
||||
else:
|
||||
constant_data_value = get_base_value_for_python_constant_data(constant_data)
|
||||
|
||||
@@ -7,7 +7,7 @@ import pytest
|
||||
from hdk.common.data_types.floats import Float
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.representation import intermediate as ir
|
||||
from hdk.common.values import ClearValue, EncryptedValue
|
||||
from hdk.common.values import ClearTensor, ClearValue, EncryptedTensor, EncryptedValue
|
||||
from hdk.hnumpy import tracing
|
||||
|
||||
OPERATIONS_TO_TEST = [ir.Add, ir.Sub, ir.Mul]
|
||||
@@ -114,6 +114,32 @@ def test_hnumpy_tracing_binary_op(operation, x, y, test_helpers):
|
||||
assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tensor_constructor",
|
||||
[
|
||||
EncryptedTensor,
|
||||
ClearTensor,
|
||||
],
|
||||
)
|
||||
def test_hnumpy_tracing_tensor_constant(tensor_constructor):
|
||||
"Test hnumpy tracing tensor constant"
|
||||
|
||||
def simple_add_tensor(x):
|
||||
return x + numpy.array([[1, 2], [3, 4]], dtype=numpy.int32)
|
||||
|
||||
op_graph = tracing.trace_numpy_function(
|
||||
simple_add_tensor, {"x": tensor_constructor(Integer(32, True), shape=(2, 2))}
|
||||
)
|
||||
|
||||
constant_inputs = [node for node in op_graph.graph.nodes() if isinstance(node, ir.Constant)]
|
||||
assert len(constant_inputs) == 1
|
||||
|
||||
constant_input_data = constant_inputs[0].constant_data
|
||||
|
||||
assert (constant_input_data == numpy.array([[1, 2], [3, 4]], dtype=numpy.int32)).all()
|
||||
assert op_graph.get_ordered_outputs()[0].outputs[0].shape == constant_input_data.shape
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,op_graph_expected_output_type,input_and_expected_output_tuples",
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user