mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat(tracing-astype): add astype method on NPTracer
This commit is contained in:
@@ -1,14 +1,44 @@
|
||||
"""hnumpy tracing utilities"""
|
||||
from typing import Callable, Dict
|
||||
|
||||
import numpy
|
||||
from numpy.typing import DTypeLike
|
||||
|
||||
from ..common.data_types import BaseValue
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..common.representation import intermediate as ir
|
||||
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
|
||||
from .np_dtypes_helpers import convert_numpy_dtype_to_common_dtype
|
||||
|
||||
|
||||
class NPTracer(BaseTracer):
|
||||
"""Tracer class for numpy operations"""
|
||||
|
||||
def astype(self, numpy_dtype: DTypeLike, *args, **kwargs) -> "NPTracer":
|
||||
"""Support numpy astype feature, for now it only accepts a dtype and no additional
|
||||
parameters, *args and **kwargs are accepted for interface compatibility only
|
||||
|
||||
Args:
|
||||
numpy_dtype (DTypeLike): The object describing a numpy type
|
||||
|
||||
Returns:
|
||||
NPTracer: The NPTracer representing the casting operation
|
||||
"""
|
||||
assert len(args) == 0, f"astype currently only supports tracing without *args, got {args}"
|
||||
assert (
|
||||
len(kwargs) == 0
|
||||
), f"astype currently only supports tracing without **kwargs, got {kwargs}"
|
||||
|
||||
normalized_numpy_dtype = numpy.dtype(numpy_dtype)
|
||||
output_dtype = convert_numpy_dtype_to_common_dtype(numpy_dtype)
|
||||
traced_computation = ir.ArbitraryFunction(
|
||||
input_base_value=self.output,
|
||||
arbitrary_func=normalized_numpy_dtype.type,
|
||||
output_dtype=output_dtype,
|
||||
)
|
||||
output_tracer = NPTracer([self], traced_computation=traced_computation, output_index=0)
|
||||
return output_tracer
|
||||
|
||||
|
||||
def trace_numpy_function(
|
||||
function_to_trace: Callable, function_parameters: Dict[str, BaseValue]
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
"""Test file for hnumpy tracing"""
|
||||
|
||||
import networkx as nx
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from hdk.common.data_types.floats import Float
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import ClearValue, EncryptedValue
|
||||
from hdk.common.representation import intermediate as ir
|
||||
@@ -109,3 +111,93 @@ def test_hnumpy_tracing_binary_op(operation, x, y, test_helpers):
|
||||
ref_graph.add_edge(input_y, returned_final_node, input_idx=1)
|
||||
|
||||
assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,op_graph_expected_output_type,input_and_expected_output_tuples",
|
||||
[
|
||||
(
|
||||
lambda x: x.astype(numpy.int32),
|
||||
Integer(32, is_signed=True),
|
||||
[
|
||||
(14, numpy.int32(14)),
|
||||
(1.5, numpy.int32(1)),
|
||||
(2.0, numpy.int32(2)),
|
||||
(-1.5, numpy.int32(-1)),
|
||||
(2 ** 31 - 1, numpy.int32(2 ** 31 - 1)),
|
||||
(-(2 ** 31), numpy.int32(-(2 ** 31))),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: x.astype(numpy.uint32),
|
||||
Integer(32, is_signed=False),
|
||||
[
|
||||
(14, numpy.uint32(14)),
|
||||
(1.5, numpy.uint32(1)),
|
||||
(2.0, numpy.uint32(2)),
|
||||
(2 ** 32 - 1, numpy.uint32(2 ** 32 - 1)),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: x.astype(numpy.int64),
|
||||
Integer(64, is_signed=True),
|
||||
[
|
||||
(14, numpy.int64(14)),
|
||||
(1.5, numpy.int64(1)),
|
||||
(2.0, numpy.int64(2)),
|
||||
(-1.5, numpy.int64(-1)),
|
||||
(2 ** 63 - 1, numpy.int64(2 ** 63 - 1)),
|
||||
(-(2 ** 63), numpy.int64(-(2 ** 63))),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: x.astype(numpy.uint64),
|
||||
Integer(64, is_signed=False),
|
||||
[
|
||||
(14, numpy.uint64(14)),
|
||||
(1.5, numpy.uint64(1)),
|
||||
(2.0, numpy.uint64(2)),
|
||||
(2 ** 64 - 1, numpy.uint64(2 ** 64 - 1)),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: x.astype(numpy.float64),
|
||||
Float(64),
|
||||
[
|
||||
(14, numpy.float64(14.0)),
|
||||
(1.5, numpy.float64(1.5)),
|
||||
(2.0, numpy.float64(2.0)),
|
||||
(-1.5, numpy.float64(-1.5)),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: x.astype(numpy.float32),
|
||||
Float(32),
|
||||
[
|
||||
(14, numpy.float32(14.0)),
|
||||
(1.5, numpy.float32(1.5)),
|
||||
(2.0, numpy.float32(2.0)),
|
||||
(-1.5, numpy.float32(-1.5)),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_tracing_astype(
|
||||
function_to_trace, op_graph_expected_output_type, input_and_expected_output_tuples
|
||||
):
|
||||
"""Test function for NPTracer.astype"""
|
||||
for input_, expected_output in input_and_expected_output_tuples:
|
||||
input_value = (
|
||||
EncryptedValue(Integer(64, is_signed=True))
|
||||
if isinstance(input_, int)
|
||||
else EncryptedValue(Float(64))
|
||||
)
|
||||
|
||||
op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value})
|
||||
output_node = op_graph.output_nodes[0]
|
||||
assert op_graph_expected_output_type == output_node.outputs[0].data_type
|
||||
|
||||
node_results = op_graph.evaluate({0: numpy.array(input_)})
|
||||
evaluated_output = node_results[output_node]
|
||||
assert isinstance(evaluated_output, type(expected_output))
|
||||
assert expected_output == evaluated_output
|
||||
|
||||
Reference in New Issue
Block a user