diff --git a/hdk/hnumpy/compile.py b/hdk/hnumpy/compile.py index a0326b6a9..2dbe045d4 100644 --- a/hdk/hnumpy/compile.py +++ b/hdk/hnumpy/compile.py @@ -2,6 +2,7 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple +import numpy from zamalang import CompilerEngine from ..common.bounds_measurement.dataset_eval import eval_op_graph_bounds_on_dataset @@ -20,6 +21,32 @@ from ..hnumpy.tracing import trace_numpy_function from .np_dtypes_helpers import get_base_data_type_for_numpy_or_python_constant_data +def numpy_max_func(lhs: Any, rhs: Any) -> Any: + """Compute the maximum value between two values which can be numpy classes (e.g. ndarray). + + Args: + lhs (Any): lhs value to compute max from. + rhs (Any): rhs value to compute max from. + + Returns: + Any: maximum scalar value between lhs and rhs. + """ + return numpy.maximum(lhs, rhs).max() + + +def numpy_min_func(lhs: Any, rhs: Any) -> Any: + """Compute the minimum value between two values which can be numpy classes (e.g. ndarray). + + Args: + lhs (Any): lhs value to compute min from. + rhs (Any): rhs value to compute min from. + + Returns: + Any: minimum scalar value between lhs and rhs. + """ + return numpy.minimum(lhs, rhs).min() + + def compile_numpy_function_into_op_graph( function_to_trace: Callable, function_parameters: Dict[str, BaseValue], @@ -72,7 +99,12 @@ def compile_numpy_function_into_op_graph( ) # Find bounds with the dataset - node_bounds = eval_op_graph_bounds_on_dataset(op_graph, dataset) + node_bounds = eval_op_graph_bounds_on_dataset( + op_graph, + dataset, + min_func=numpy_min_func, + max_func=numpy_max_func, + ) # Update the graph accordingly: after that, we have the compilable graph op_graph.update_values_with_bounds( diff --git a/hdk/hnumpy/np_dtypes_helpers.py b/hdk/hnumpy/np_dtypes_helpers.py index 0cf1792ae..820405005 100644 --- a/hdk/hnumpy/np_dtypes_helpers.py +++ b/hdk/hnumpy/np_dtypes_helpers.py @@ -15,7 +15,7 @@ from ..common.data_types.dtypes_helpers import ( ) from ..common.data_types.floats import Float from ..common.data_types.integers import Integer -from ..common.values import BaseValue, ScalarValue +from ..common.values import BaseValue, ScalarValue, TensorValue NUMPY_TO_HDK_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = { numpy.dtype(numpy.int32): Integer(32, is_signed=True), @@ -110,11 +110,13 @@ def get_base_data_type_for_numpy_or_python_constant_data(constant_data: Any) -> """ base_dtype: BaseDataType assert isinstance( - constant_data, (int, float, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES) + constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES) ), f"Unsupported constant data of type {type(constant_data)}" - if isinstance(constant_data, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES): - base_dtype = convert_numpy_dtype_to_base_data_type(constant_data) + if isinstance(constant_data, (numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)): + # numpy + base_dtype = convert_numpy_dtype_to_base_data_type(constant_data.dtype) else: + # python base_dtype = get_base_data_type_for_python_constant_data(constant_data) return base_dtype @@ -139,11 +141,13 @@ def get_base_value_for_numpy_or_python_constant_data( """ constant_data_value: Callable[..., BaseValue] assert isinstance( - constant_data, (int, float, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES) + constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES) ), f"Unsupported constant data of type {type(constant_data)}" base_dtype = get_base_data_type_for_numpy_or_python_constant_data(constant_data) - if isinstance(constant_data, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES): + if isinstance(constant_data, numpy.ndarray): + constant_data_value = partial(TensorValue, data_type=base_dtype, shape=constant_data.shape) + elif isinstance(constant_data, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES): constant_data_value = partial(ScalarValue, data_type=base_dtype) else: constant_data_value = get_base_value_for_python_constant_data(constant_data) diff --git a/tests/hnumpy/test_tracing.py b/tests/hnumpy/test_tracing.py index 3405d13bc..529b38fd5 100644 --- a/tests/hnumpy/test_tracing.py +++ b/tests/hnumpy/test_tracing.py @@ -7,7 +7,7 @@ import pytest from hdk.common.data_types.floats import Float from hdk.common.data_types.integers import Integer from hdk.common.representation import intermediate as ir -from hdk.common.values import ClearValue, EncryptedValue +from hdk.common.values import ClearTensor, ClearValue, EncryptedTensor, EncryptedValue from hdk.hnumpy import tracing OPERATIONS_TO_TEST = [ir.Add, ir.Sub, ir.Mul] @@ -114,6 +114,32 @@ def test_hnumpy_tracing_binary_op(operation, x, y, test_helpers): assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph) +@pytest.mark.parametrize( + "tensor_constructor", + [ + EncryptedTensor, + ClearTensor, + ], +) +def test_hnumpy_tracing_tensor_constant(tensor_constructor): + "Test hnumpy tracing tensor constant" + + def simple_add_tensor(x): + return x + numpy.array([[1, 2], [3, 4]], dtype=numpy.int32) + + op_graph = tracing.trace_numpy_function( + simple_add_tensor, {"x": tensor_constructor(Integer(32, True), shape=(2, 2))} + ) + + constant_inputs = [node for node in op_graph.graph.nodes() if isinstance(node, ir.Constant)] + assert len(constant_inputs) == 1 + + constant_input_data = constant_inputs[0].constant_data + + assert (constant_input_data == numpy.array([[1, 2], [3, 4]], dtype=numpy.int32)).all() + assert op_graph.get_ordered_outputs()[0].outputs[0].shape == constant_input_data.shape + + @pytest.mark.parametrize( "function_to_trace,op_graph_expected_output_type,input_and_expected_output_tuples", [