mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-19 08:54:26 -05:00
feat(nptracer): add dot tracing abilities
- remove no cover from Dot.label - small refactor of BaseTracer to make _sanitize a class method - small refactor of get_ufunc_numpy_output_dtype to manage funcs and ufuncs - add function routing to NPTracer - add dot tracing to NPTracer - small refactor to get tracing functions for numpy funcs and ufuncs
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
"""Test file for hnumpy debugging functions"""
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.debugging import draw_graph, get_printable_graph
|
||||
from hdk.common.extensions.table import LookupTable
|
||||
from hdk.common.values import ClearValue, EncryptedValue
|
||||
from hdk.common.values import ClearValue, EncryptedTensor, EncryptedValue
|
||||
from hdk.hnumpy import tracing
|
||||
|
||||
LOOKUP_TABLE_FROM_2B_TO_4B = LookupTable([9, 2, 4, 11])
|
||||
@@ -178,6 +179,36 @@ def test_hnumpy_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lambda_f,params,ref_graph_str",
|
||||
[
|
||||
# pylint: disable=unnecessary-lambda
|
||||
(
|
||||
lambda x, y: numpy.dot(x, y),
|
||||
{
|
||||
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(3,)),
|
||||
"y": EncryptedTensor(Integer(2, is_signed=False), shape=(3,)),
|
||||
},
|
||||
"\n%0 = x\n%1 = y\n%2 = Dot(0, 1)\nreturn(%2)",
|
||||
),
|
||||
# pylint: enable=unnecessary-lambda
|
||||
],
|
||||
)
|
||||
def test_hnumpy_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str):
|
||||
"Test hnumpy get_printable_graph and draw_graph on graphs with dot"
|
||||
graph = tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
draw_graph(graph, show=False)
|
||||
|
||||
str_of_the_graph = get_printable_graph(graph)
|
||||
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot {str_of_the_graph}"
|
||||
f"\n==================\nExpected {ref_graph_str}"
|
||||
f"\n==================\n"
|
||||
)
|
||||
|
||||
|
||||
# Remark that the bitwidths are not particularly correct (eg, a MUL of a 17b times 23b
|
||||
# returning 23b), since they are replaced later by the real bitwidths computed on the
|
||||
# dataset
|
||||
|
||||
Reference in New Issue
Block a user