feat(tracing): update hnumpy to manage tensor types

- binary ops and constants support
This commit is contained in:
Arthur Meyre
2021-08-23 14:42:50 +02:00
parent 96b04b45e1
commit 66d0c8dd62
3 changed files with 70 additions and 8 deletions

View File

@@ -2,6 +2,7 @@
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
import numpy
from zamalang import CompilerEngine
from ..common.bounds_measurement.dataset_eval import eval_op_graph_bounds_on_dataset
@@ -20,6 +21,32 @@ from ..hnumpy.tracing import trace_numpy_function
from .np_dtypes_helpers import get_base_data_type_for_numpy_or_python_constant_data
def numpy_max_func(lhs: Any, rhs: Any) -> Any:
    """Return the largest scalar found across two numpy-compatible operands.

    The operands are first combined element-wise with ``numpy.maximum`` and the
    combined result is then reduced with its ``max`` method, so either operand
    may be a plain scalar or a numpy object such as an ndarray.

    Args:
        lhs (Any): first operand taking part in the comparison.
        rhs (Any): second operand taking part in the comparison.

    Returns:
        Any: the single greatest value present in lhs and rhs.
    """
    # Element-wise step: larger of each pair of entries from lhs and rhs.
    elementwise_largest = numpy.maximum(lhs, rhs)
    # Reduction step: collapse the element-wise result to one value.
    return elementwise_largest.max()
def numpy_min_func(lhs: Any, rhs: Any) -> Any:
    """Return the smallest scalar found across two numpy-compatible operands.

    The operands are first combined element-wise with ``numpy.minimum`` and the
    combined result is then reduced with its ``min`` method, so either operand
    may be a plain scalar or a numpy object such as an ndarray.

    Args:
        lhs (Any): first operand taking part in the comparison.
        rhs (Any): second operand taking part in the comparison.

    Returns:
        Any: the single lowest value present in lhs and rhs.
    """
    # Element-wise step: smaller of each pair of entries from lhs and rhs.
    elementwise_smallest = numpy.minimum(lhs, rhs)
    # Reduction step: collapse the element-wise result to one value.
    return elementwise_smallest.min()
def compile_numpy_function_into_op_graph(
function_to_trace: Callable,
function_parameters: Dict[str, BaseValue],
@@ -72,7 +99,12 @@ def compile_numpy_function_into_op_graph(
)
# Find bounds with the dataset
node_bounds = eval_op_graph_bounds_on_dataset(op_graph, dataset)
node_bounds = eval_op_graph_bounds_on_dataset(
op_graph,
dataset,
min_func=numpy_min_func,
max_func=numpy_max_func,
)
# Update the graph accordingly: after that, we have the compilable graph
op_graph.update_values_with_bounds(

View File

@@ -15,7 +15,7 @@ from ..common.data_types.dtypes_helpers import (
)
from ..common.data_types.floats import Float
from ..common.data_types.integers import Integer
from ..common.values import BaseValue, ScalarValue
from ..common.values import BaseValue, ScalarValue, TensorValue
NUMPY_TO_HDK_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = {
numpy.dtype(numpy.int32): Integer(32, is_signed=True),
@@ -110,11 +110,13 @@ def get_base_data_type_for_numpy_or_python_constant_data(constant_data: Any) ->
"""
base_dtype: BaseDataType
assert isinstance(
constant_data, (int, float, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
), f"Unsupported constant data of type {type(constant_data)}"
if isinstance(constant_data, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES):
base_dtype = convert_numpy_dtype_to_base_data_type(constant_data)
if isinstance(constant_data, (numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)):
# numpy
base_dtype = convert_numpy_dtype_to_base_data_type(constant_data.dtype)
else:
# python
base_dtype = get_base_data_type_for_python_constant_data(constant_data)
return base_dtype
@@ -139,11 +141,13 @@ def get_base_value_for_numpy_or_python_constant_data(
"""
constant_data_value: Callable[..., BaseValue]
assert isinstance(
constant_data, (int, float, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
), f"Unsupported constant data of type {type(constant_data)}"
base_dtype = get_base_data_type_for_numpy_or_python_constant_data(constant_data)
if isinstance(constant_data, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES):
if isinstance(constant_data, numpy.ndarray):
constant_data_value = partial(TensorValue, data_type=base_dtype, shape=constant_data.shape)
elif isinstance(constant_data, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES):
constant_data_value = partial(ScalarValue, data_type=base_dtype)
else:
constant_data_value = get_base_value_for_python_constant_data(constant_data)

View File

@@ -7,7 +7,7 @@ import pytest
from hdk.common.data_types.floats import Float
from hdk.common.data_types.integers import Integer
from hdk.common.representation import intermediate as ir
from hdk.common.values import ClearValue, EncryptedValue
from hdk.common.values import ClearTensor, ClearValue, EncryptedTensor, EncryptedValue
from hdk.hnumpy import tracing
OPERATIONS_TO_TEST = [ir.Add, ir.Sub, ir.Mul]
@@ -114,6 +114,32 @@ def test_hnumpy_tracing_binary_op(operation, x, y, test_helpers):
assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)
@pytest.mark.parametrize(
    "tensor_constructor",
    [
        EncryptedTensor,
        ClearTensor,
    ],
)
def test_hnumpy_tracing_tensor_constant(tensor_constructor):
    """Check that adding a numpy array literal is traced as one tensor constant."""

    def simple_add_tensor(x):
        return x + numpy.array([[1, 2], [3, 4]], dtype=numpy.int32)

    # Trace with a 2x2 signed 32-bit integer tensor input built by the parametrized constructor.
    traced_parameters = {"x": tensor_constructor(Integer(32, True), shape=(2, 2))}
    op_graph = tracing.trace_numpy_function(simple_add_tensor, traced_parameters)

    # Exactly one Constant node is expected: the array literal added to x.
    constant_nodes = [
        graph_node for graph_node in op_graph.graph.nodes() if isinstance(graph_node, ir.Constant)
    ]
    assert len(constant_nodes) == 1

    # The recorded constant must hold the same values as the literal.
    traced_constant = constant_nodes[0].constant_data
    expected_constant = numpy.array([[1, 2], [3, 4]], dtype=numpy.int32)
    assert (traced_constant == expected_constant).all()

    # The first output value of the graph must have the shape of the constant.
    first_output_value = op_graph.get_ordered_outputs()[0].outputs[0]
    assert first_output_value.shape == traced_constant.shape
@pytest.mark.parametrize(
"function_to_trace,op_graph_expected_output_type,input_and_expected_output_tuples",
[