From 94b9d83cbb485f209d2f8ea9d56b2c782da880bb Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Mon, 13 Dec 2021 16:36:08 +0100 Subject: [PATCH] feat: support 1D vectors in MatMul closes #948 --- concrete/numpy/compile.py | 2 +- concrete/numpy/np_mlir_converter.py | 2 +- concrete/numpy/tracing.py | 32 +++++++++++++++++++++++++++-- tests/numpy/test_compile.py | 15 ++++++++++++++ 4 files changed, 47 insertions(+), 4 deletions(-) diff --git a/concrete/numpy/compile.py b/concrete/numpy/compile.py index 20fb93d87..b152b1f1b 100644 --- a/concrete/numpy/compile.py +++ b/concrete/numpy/compile.py @@ -536,7 +536,7 @@ def hack_offset_negative_inputs_to_lookup_tables(op_graph: OPGraph) -> None: # This does not update the TLU input values to allow for proper table generation. # Thankfully we are not supposed to touch the op_graph beyond that point for node in list((nx_graph := op_graph.graph).nodes): - if isinstance(node, GenericFunction): + if isinstance(node, GenericFunction) and node.op_kind == "TLU": ordered_preds_and_inputs = op_graph.get_ordered_preds_and_inputs_of(node) variable_input_indices = [ idx diff --git a/concrete/numpy/np_mlir_converter.py b/concrete/numpy/np_mlir_converter.py index f937e96cb..0ffe00327 100644 --- a/concrete/numpy/np_mlir_converter.py +++ b/concrete/numpy/np_mlir_converter.py @@ -79,7 +79,7 @@ class NPMLIRConverter(OPGraphConverter): additional_conversion_info["tables"] = { node: generate_deduplicated_tables(node, op_graph.get_ordered_preds(node)) for node in op_graph.graph.nodes() - if isinstance(node, GenericFunction) + if isinstance(node, GenericFunction) and node.op_kind == "TLU" } return additional_conversion_info diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index d96efd66a..32dd4f9ab 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -642,17 +642,45 @@ def _on_numpy_multiply(lhs, rhs): return lhs.__mul__(rhs) -def _on_numpy_matmul(lhs, rhs): +def _on_numpy_matmul(lhs: NPTracer, rhs: NPTracer): common_output_dtypes_and_shapes = get_numpy_function_output_dtype_and_shape_from_input_tracers( numpy.matmul, lhs, rhs ) assert_true(len(common_output_dtypes_and_shapes) == 1) + output_shape = common_output_dtypes_and_shapes[0][1] + + # TODO: https://github.com/zama-ai/concretefhe-internal/issues/1174 + # remove all the reshape logic once matmul supports more combinations of arguments + if isinstance(lhs_output := lhs.output, TensorValue) and isinstance( + rhs_output := rhs.output, TensorValue + ): + # Manage non 2D cases + if lhs_output.ndim == 1 and rhs_output.ndim == 1: + lhs = lhs.reshape((1, lhs_output.shape[0])) + rhs = rhs.reshape((rhs_output.shape[0], 1)) + elif lhs_output.ndim == 1: + # lhs is a vector, reshape to be 2D and give proper result + output_shape = lhs_output.shape + lhs = lhs.reshape((1, lhs_output.shape[0])) + elif rhs_output.ndim == 1: + # rhs is a vector, reshape to be 2D and give proper result + output_shape = rhs_output.shape + rhs = rhs.reshape((rhs_output.shape[0], 1)) + traced_computation = MatMul( [lhs.output, rhs.output], common_output_dtypes_and_shapes[0][0], ) - return NPTracer([lhs, rhs], traced_computation, output_idx=0) + + matmul_tracer = NPTracer([lhs, rhs], traced_computation, output_idx=0) + + # Return the reshaped result if vector reshaping for 2D matmul happened + if matmul_tracer.shape != output_shape: + if output_shape == (): + return matmul_tracer[0, 0] + return matmul_tracer.reshape(output_shape) + return matmul_tracer NPTracer.UFUNC_ROUTING[numpy.add] = _on_numpy_add diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 49146f91f..b80fdc71f 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -1338,6 +1338,21 @@ def test_compile_and_run_constant_dot_correctness( (1, 2), (0, 8), ), + pytest.param( + (2,), + (2,), + (0, 8), + ), + pytest.param( + (5, 5), + (5,), + (0, 4), + ), + pytest.param( + (5,), + (5, 5), + (0, 4), + ), ], ) def test_compile_and_run_matmul_correctness(