mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
@@ -536,7 +536,7 @@ def hack_offset_negative_inputs_to_lookup_tables(op_graph: OPGraph) -> None:
|
||||
# This does not update the TLU input values to allow for proper table generation.
|
||||
# Thankfully we are not supposed to touch the op_graph beyond that point
|
||||
for node in list((nx_graph := op_graph.graph).nodes):
|
||||
if isinstance(node, GenericFunction):
|
||||
if isinstance(node, GenericFunction) and node.op_kind == "TLU":
|
||||
ordered_preds_and_inputs = op_graph.get_ordered_preds_and_inputs_of(node)
|
||||
variable_input_indices = [
|
||||
idx
|
||||
|
||||
@@ -79,7 +79,7 @@ class NPMLIRConverter(OPGraphConverter):
|
||||
additional_conversion_info["tables"] = {
|
||||
node: generate_deduplicated_tables(node, op_graph.get_ordered_preds(node))
|
||||
for node in op_graph.graph.nodes()
|
||||
if isinstance(node, GenericFunction)
|
||||
if isinstance(node, GenericFunction) and node.op_kind == "TLU"
|
||||
}
|
||||
|
||||
return additional_conversion_info
|
||||
|
||||
@@ -642,17 +642,45 @@ def _on_numpy_multiply(lhs, rhs):
|
||||
return lhs.__mul__(rhs)
|
||||
|
||||
|
||||
def _on_numpy_matmul(lhs, rhs):
|
||||
def _on_numpy_matmul(lhs: NPTracer, rhs: NPTracer):
|
||||
common_output_dtypes_and_shapes = get_numpy_function_output_dtype_and_shape_from_input_tracers(
|
||||
numpy.matmul, lhs, rhs
|
||||
)
|
||||
assert_true(len(common_output_dtypes_and_shapes) == 1)
|
||||
|
||||
output_shape = common_output_dtypes_and_shapes[0][1]
|
||||
|
||||
# TODO: https://github.com/zama-ai/concretefhe-internal/issues/1174
|
||||
# remove all the reshape logic once matmul supports more combinations of arguments
|
||||
if isinstance(lhs_output := lhs.output, TensorValue) and isinstance(
|
||||
rhs_output := rhs.output, TensorValue
|
||||
):
|
||||
# Manage non 2D cases
|
||||
if lhs_output.ndim == 1 and rhs_output.ndim == 1:
|
||||
lhs = lhs.reshape((1, lhs_output.shape[0]))
|
||||
rhs = rhs.reshape((rhs_output.shape[0], 1))
|
||||
elif lhs_output.ndim == 1:
|
||||
# lhs is a vector, reshape to be 2D and give proper result
|
||||
output_shape = lhs_output.shape
|
||||
lhs = lhs.reshape((1, lhs_output.shape[0]))
|
||||
elif rhs_output.ndim == 1:
|
||||
# rhs is a vector, reshape to be 2D and give proper result
|
||||
output_shape = rhs_output.shape
|
||||
rhs = rhs.reshape((rhs_output.shape[0], 1))
|
||||
|
||||
traced_computation = MatMul(
|
||||
[lhs.output, rhs.output],
|
||||
common_output_dtypes_and_shapes[0][0],
|
||||
)
|
||||
return NPTracer([lhs, rhs], traced_computation, output_idx=0)
|
||||
|
||||
matmul_tracer = NPTracer([lhs, rhs], traced_computation, output_idx=0)
|
||||
|
||||
# Return the reshaped result if vector reshaping for 2D matmul happened
|
||||
if matmul_tracer.shape != output_shape:
|
||||
if output_shape == ():
|
||||
return matmul_tracer[0, 0]
|
||||
return matmul_tracer.reshape(output_shape)
|
||||
return matmul_tracer
|
||||
|
||||
|
||||
NPTracer.UFUNC_ROUTING[numpy.add] = _on_numpy_add
|
||||
|
||||
@@ -1338,6 +1338,21 @@ def test_compile_and_run_constant_dot_correctness(
|
||||
(1, 2),
|
||||
(0, 8),
|
||||
),
|
||||
pytest.param(
|
||||
(2,),
|
||||
(2,),
|
||||
(0, 8),
|
||||
),
|
||||
pytest.param(
|
||||
(5, 5),
|
||||
(5,),
|
||||
(0, 4),
|
||||
),
|
||||
pytest.param(
|
||||
(5,),
|
||||
(5, 5),
|
||||
(0, 4),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_matmul_correctness(
|
||||
|
||||
Reference in New Issue
Block a user