feat: support 1D vectors in MatMul

closes #948
This commit is contained in:
Arthur Meyre
2021-12-13 16:36:08 +01:00
parent 214eebb458
commit 94b9d83cbb
4 changed files with 47 additions and 4 deletions

View File

@@ -536,7 +536,7 @@ def hack_offset_negative_inputs_to_lookup_tables(op_graph: OPGraph) -> None:
# This does not update the TLU input values to allow for proper table generation.
# Thankfully we are not supposed to touch the op_graph beyond that point
for node in list((nx_graph := op_graph.graph).nodes):
if isinstance(node, GenericFunction):
if isinstance(node, GenericFunction) and node.op_kind == "TLU":
ordered_preds_and_inputs = op_graph.get_ordered_preds_and_inputs_of(node)
variable_input_indices = [
idx

View File

@@ -79,7 +79,7 @@ class NPMLIRConverter(OPGraphConverter):
additional_conversion_info["tables"] = {
node: generate_deduplicated_tables(node, op_graph.get_ordered_preds(node))
for node in op_graph.graph.nodes()
if isinstance(node, GenericFunction)
if isinstance(node, GenericFunction) and node.op_kind == "TLU"
}
return additional_conversion_info

View File

@@ -642,17 +642,45 @@ def _on_numpy_multiply(lhs, rhs):
return lhs.__mul__(rhs)
def _on_numpy_matmul(lhs, rhs):
def _on_numpy_matmul(lhs: NPTracer, rhs: NPTracer):
common_output_dtypes_and_shapes = get_numpy_function_output_dtype_and_shape_from_input_tracers(
numpy.matmul, lhs, rhs
)
assert_true(len(common_output_dtypes_and_shapes) == 1)
output_shape = common_output_dtypes_and_shapes[0][1]
# TODO: https://github.com/zama-ai/concretefhe-internal/issues/1174
# remove all the reshape logic once matmul supports more combinations of arguments
if isinstance(lhs_output := lhs.output, TensorValue) and isinstance(
rhs_output := rhs.output, TensorValue
):
# Manage non 2D cases
if lhs_output.ndim == 1 and rhs_output.ndim == 1:
lhs = lhs.reshape((1, lhs_output.shape[0]))
rhs = rhs.reshape((rhs_output.shape[0], 1))
elif lhs_output.ndim == 1:
# lhs is a vector, reshape to be 2D and give proper result
output_shape = lhs_output.shape
lhs = lhs.reshape((1, lhs_output.shape[0]))
elif rhs_output.ndim == 1:
# rhs is a vector, reshape to be 2D and give proper result
output_shape = rhs_output.shape
rhs = rhs.reshape((rhs_output.shape[0], 1))
traced_computation = MatMul(
[lhs.output, rhs.output],
common_output_dtypes_and_shapes[0][0],
)
return NPTracer([lhs, rhs], traced_computation, output_idx=0)
matmul_tracer = NPTracer([lhs, rhs], traced_computation, output_idx=0)
# Return the reshaped result if vector reshaping for 2D matmul happened
if matmul_tracer.shape != output_shape:
if output_shape == ():
return matmul_tracer[0, 0]
return matmul_tracer.reshape(output_shape)
return matmul_tracer
NPTracer.UFUNC_ROUTING[numpy.add] = _on_numpy_add

View File

@@ -1338,6 +1338,21 @@ def test_compile_and_run_constant_dot_correctness(
(1, 2),
(0, 8),
),
pytest.param(
(2,),
(2,),
(0, 8),
),
pytest.param(
(5, 5),
(5,),
(0, 4),
),
pytest.param(
(5,),
(5, 5),
(0, 4),
),
],
)
def test_compile_and_run_matmul_correctness(