feat(tracing-astype): add astype method on NPTracer

This commit is contained in:
Arthur Meyre
2021-08-09 17:28:30 +02:00
parent 9ef2154d51
commit c51c4bd17a
2 changed files with 122 additions and 0 deletions

View File

@@ -1,14 +1,44 @@
"""hnumpy tracing utilities"""
from typing import Callable, Dict
import numpy
from numpy.typing import DTypeLike
from ..common.data_types import BaseValue
from ..common.operator_graph import OPGraph
from ..common.representation import intermediate as ir
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
from .np_dtypes_helpers import convert_numpy_dtype_to_common_dtype
class NPTracer(BaseTracer):
"""Tracer class for numpy operations"""
def astype(self, numpy_dtype: DTypeLike, *args, **kwargs) -> "NPTracer":
"""Support numpy astype feature, for now it only accepts a dtype and no additional
parameters, *args and **kwargs are accepted for interface compatibility only
Args:
numpy_dtype (DTypeLike): The object describing a numpy type
Returns:
NPTracer: The NPTracer representing the casting operation
"""
assert len(args) == 0, f"astype currently only supports tracing without *args, got {args}"
assert (
len(kwargs) == 0
), f"astype currently only supports tracing without **kwargs, got {kwargs}"
normalized_numpy_dtype = numpy.dtype(numpy_dtype)
output_dtype = convert_numpy_dtype_to_common_dtype(numpy_dtype)
traced_computation = ir.ArbitraryFunction(
input_base_value=self.output,
arbitrary_func=normalized_numpy_dtype.type,
output_dtype=output_dtype,
)
output_tracer = NPTracer([self], traced_computation=traced_computation, output_index=0)
return output_tracer
def trace_numpy_function(
function_to_trace: Callable, function_parameters: Dict[str, BaseValue]

View File

@@ -1,8 +1,10 @@
"""Test file for hnumpy tracing"""
import networkx as nx
import numpy
import pytest
from hdk.common.data_types.floats import Float
from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import ClearValue, EncryptedValue
from hdk.common.representation import intermediate as ir
@@ -109,3 +111,93 @@ def test_hnumpy_tracing_binary_op(operation, x, y, test_helpers):
ref_graph.add_edge(input_y, returned_final_node, input_idx=1)
assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)
@pytest.mark.parametrize(
"function_to_trace,op_graph_expected_output_type,input_and_expected_output_tuples",
[
(
lambda x: x.astype(numpy.int32),
Integer(32, is_signed=True),
[
(14, numpy.int32(14)),
(1.5, numpy.int32(1)),
(2.0, numpy.int32(2)),
(-1.5, numpy.int32(-1)),
(2 ** 31 - 1, numpy.int32(2 ** 31 - 1)),
(-(2 ** 31), numpy.int32(-(2 ** 31))),
],
),
(
lambda x: x.astype(numpy.uint32),
Integer(32, is_signed=False),
[
(14, numpy.uint32(14)),
(1.5, numpy.uint32(1)),
(2.0, numpy.uint32(2)),
(2 ** 32 - 1, numpy.uint32(2 ** 32 - 1)),
],
),
(
lambda x: x.astype(numpy.int64),
Integer(64, is_signed=True),
[
(14, numpy.int64(14)),
(1.5, numpy.int64(1)),
(2.0, numpy.int64(2)),
(-1.5, numpy.int64(-1)),
(2 ** 63 - 1, numpy.int64(2 ** 63 - 1)),
(-(2 ** 63), numpy.int64(-(2 ** 63))),
],
),
(
lambda x: x.astype(numpy.uint64),
Integer(64, is_signed=False),
[
(14, numpy.uint64(14)),
(1.5, numpy.uint64(1)),
(2.0, numpy.uint64(2)),
(2 ** 64 - 1, numpy.uint64(2 ** 64 - 1)),
],
),
(
lambda x: x.astype(numpy.float64),
Float(64),
[
(14, numpy.float64(14.0)),
(1.5, numpy.float64(1.5)),
(2.0, numpy.float64(2.0)),
(-1.5, numpy.float64(-1.5)),
],
),
(
lambda x: x.astype(numpy.float32),
Float(32),
[
(14, numpy.float32(14.0)),
(1.5, numpy.float32(1.5)),
(2.0, numpy.float32(2.0)),
(-1.5, numpy.float32(-1.5)),
],
),
],
)
def test_tracing_astype(
function_to_trace, op_graph_expected_output_type, input_and_expected_output_tuples
):
"""Test function for NPTracer.astype"""
for input_, expected_output in input_and_expected_output_tuples:
input_value = (
EncryptedValue(Integer(64, is_signed=True))
if isinstance(input_, int)
else EncryptedValue(Float(64))
)
op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value})
output_node = op_graph.output_nodes[0]
assert op_graph_expected_output_type == output_node.outputs[0].data_type
node_results = op_graph.evaluate({0: numpy.array(input_)})
evaluated_output = node_results[output_node]
assert isinstance(evaluated_output, type(expected_output))
assert expected_output == evaluated_output