feat(nptracer): add dot tracing abilities

- remove no cover from Dot.label
- small refactor of BaseTracer to make _sanitize a class method
- small refactor of get_ufunc_numpy_output_dtype to manage funcs and ufuncs
- add function routing to NPTracer
- add dot tracing to NPTracer
- small refactor to get tracing functions for numpy funcs and ufuncs
This commit is contained in:
Arthur Meyre
2021-08-24 11:52:35 +02:00
parent 6d663ef63d
commit ff260b2cd2
6 changed files with 176 additions and 40 deletions

View File

@@ -1,11 +1,12 @@
"""Test file for hnumpy debugging functions"""
import numpy
import pytest
from hdk.common.data_types.integers import Integer
from hdk.common.debugging import draw_graph, get_printable_graph
from hdk.common.extensions.table import LookupTable
from hdk.common.values import ClearValue, EncryptedValue
from hdk.common.values import ClearValue, EncryptedTensor, EncryptedValue
from hdk.hnumpy import tracing
LOOKUP_TABLE_FROM_2B_TO_4B = LookupTable([9, 2, 4, 11])
@@ -178,6 +179,36 @@ def test_hnumpy_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph
)
@pytest.mark.parametrize(
"lambda_f,params,ref_graph_str",
[
# pylint: disable=unnecessary-lambda
(
lambda x, y: numpy.dot(x, y),
{
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(3,)),
"y": EncryptedTensor(Integer(2, is_signed=False), shape=(3,)),
},
"\n%0 = x\n%1 = y\n%2 = Dot(0, 1)\nreturn(%2)",
),
# pylint: enable=unnecessary-lambda
],
)
def test_hnumpy_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str):
"Test hnumpy get_printable_graph and draw_graph on graphs with dot"
graph = tracing.trace_numpy_function(lambda_f, params)
draw_graph(graph, show=False)
str_of_the_graph = get_printable_graph(graph)
assert str_of_the_graph == ref_graph_str, (
f"\n==================\nGot {str_of_the_graph}"
f"\n==================\nExpected {ref_graph_str}"
f"\n==================\n"
)
# Remark that the bitwidths are not particularly correct (eg, a MUL of a 17b times 23b
# returning 23b), since they are replaced later by the real bitwidths computed on the
# dataset