Add two passes related to the tiling of `HLFHELinalg.matmul_eint_int`
operations.
The `HLFHELinalgTilingMarker` pass takes a vector of tile sizes and
adds an integer array attribute "tile-sizes" to each instance of
`HLFHELinalg.matmul_eint_int`, e.g.,
  "HLFHELinalg.matmul_eint_int"(%arg0, %arg1) {"tile-sizes" = [2, 2, 2]} :
    (tensor<4x2x!HLFHE.eint<6>>, tensor<2x2xi7>) -> tensor<4x2x!HLFHE.eint<6>>
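As a rough illustration of what the marker pass does, the Python sketch below models operations as plain objects. `Op`, `mark_tiling` and the list-of-ops representation are hypothetical stand-ins, not the MLIR C++ API used by the actual pass; only the operation name and the "tile-sizes" attribute come from this description.

```python
from dataclasses import dataclass, field

# Operation name taken from the description above.
MATMUL = "HLFHELinalg.matmul_eint_int"


@dataclass
class Op:
    """Hypothetical toy stand-in for an MLIR operation (not the MLIR API)."""
    name: str
    attributes: dict = field(default_factory=dict)


def mark_tiling(ops, tile_sizes):
    """Attach a "tile-sizes" integer array to every matmul_eint_int op.

    Other operations are left untouched, mirroring the marker pass,
    which only annotates and does not transform the IR.
    """
    for op in ops:
        if op.name == MATMUL:
            op.attributes["tile-sizes"] = [int(t) for t in tile_sizes]
    return ops
```

For example, running `mark_tiling([Op(MATMUL), Op("HLFHELinalg.add_eint")], [2, 2, 2])` annotates only the first operation.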
The `HLFHELinalgTiling` pass performs the actual tiling of each
`HLFHELinalg.matmul_eint_int` operation marked with a "tile-sizes"
attribute. The tiling preserves the level of abstraction of
HLFHELinalg: the result is a perfect loop nest of SCF for loops whose
innermost body multiplies one pair of tiles with
`HLFHELinalg.matmul_eint_int` and accumulates the product into the
result with `HLFHELinalg.add_eint`.
For example,
  func @main(%arg0: tensor<4x2x!HLFHE.eint<6>>, %arg1: tensor<2x2xi7>)
      -> tensor<4x2x!HLFHE.eint<6>>
  {
    %0 = "HLFHELinalg.matmul_eint_int"(%arg0, %arg1) {"tile-sizes" = [2, 2, 2]} :
      (tensor<4x2x!HLFHE.eint<6>>, tensor<2x2xi7>) -> tensor<4x2x!HLFHE.eint<6>>
    return %0 : tensor<4x2x!HLFHE.eint<6>>
  }
becomes:
  func @main(%arg0: tensor<4x2x!HLFHE.eint<6>>, %arg1: tensor<2x2xi7>)
      -> tensor<4x2x!HLFHE.eint<6>>
  {
    %c2 = arith.constant 2 : index
    %c0 = arith.constant 0 : index
    %c4 = arith.constant 4 : index
    %0 = "HLFHELinalg.zero"() : () -> tensor<4x2x!HLFHE.eint<6>>
    %1 = scf.for %arg2 = %c0 to %c4 step %c2 iter_args(%arg3 = %0)
        -> (tensor<4x2x!HLFHE.eint<6>>) {
      %2 = scf.for %arg4 = %c0 to %c2 step %c2 iter_args(%arg5 = %arg3)
          -> (tensor<4x2x!HLFHE.eint<6>>) {
        %3 = scf.for %arg6 = %c0 to %c2 step %c2 iter_args(%arg7 = %arg5)
            -> (tensor<4x2x!HLFHE.eint<6>>) {
          %4 = tensor.extract_slice %arg0[%arg2, %arg4] [2, 2] [1, 1] :
            tensor<4x2x!HLFHE.eint<6>> to tensor<2x2x!HLFHE.eint<6>>
          %5 = tensor.extract_slice %arg1[%arg4, %arg6] [2, 2] [1, 1] :
            tensor<2x2xi7> to tensor<2x2xi7>
          %6 = tensor.extract_slice %arg7[%arg2, %arg6] [2, 2] [1, 1] :
            tensor<4x2x!HLFHE.eint<6>> to tensor<2x2x!HLFHE.eint<6>>
          %7 = "HLFHELinalg.matmul_eint_int"(%4, %5) :
            (tensor<2x2x!HLFHE.eint<6>>, tensor<2x2xi7>)
              -> tensor<2x2x!HLFHE.eint<6>>
          %8 = "HLFHELinalg.add_eint"(%6, %7) :
            (tensor<2x2x!HLFHE.eint<6>>, tensor<2x2x!HLFHE.eint<6>>)
              -> tensor<2x2x!HLFHE.eint<6>>
          %9 = tensor.insert_slice %8 into %arg7[%arg2, %arg6] [2, 2] [1, 1] :
            tensor<2x2x!HLFHE.eint<6>> into tensor<4x2x!HLFHE.eint<6>>
          scf.yield %9 : tensor<4x2x!HLFHE.eint<6>>
        }
        scf.yield %3 : tensor<4x2x!HLFHE.eint<6>>
      }
      scf.yield %2 : tensor<4x2x!HLFHE.eint<6>>
    }
    return %1 : tensor<4x2x!HLFHE.eint<6>>
  }
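The semantics of the tiled loop nest can be checked on plaintext values with the Python sketch below. It is a model, not the pass: plain Python ints stand in for `!HLFHE.eint` values, the helper names are hypothetical, and the mapping of "tile-sizes" to `(tm, tk, tn)` is an assumption (all tile sizes in the example are 2, so the order cannot be read off it). The loop order (rows, reduction, columns) and the extract/matmul/add/insert steps follow the MLIR output above, and the operand dimensions are assumed to be multiples of the tile sizes.

```python
def matmul(a, b):
    """Untiled reference: (M x K) @ (K x N) -> (M x N)."""
    m, k, n = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
            for i in range(m)]


def extract_slice(t, r0, c0, rows, cols):
    """Model of tensor.extract_slice with unit strides."""
    return [row[c0:c0 + cols] for row in t[r0:r0 + rows]]


def insert_slice(src, dst, r0, c0):
    """Model of tensor.insert_slice: returns a new tensor, dst is unchanged."""
    out = [row[:] for row in dst]
    for di, row in enumerate(src):
        out[r0 + di][c0:c0 + len(row)] = row
    return out


def add(x, y):
    """Element-wise addition, standing in for HLFHELinalg.add_eint."""
    return [[u + v for u, v in zip(rx, ry)] for rx, ry in zip(x, y)]


def tiled_matmul(a, b, tiles):
    """Perfect loop nest over tiles; tiles = (tm, tk, tn) is an assumed order."""
    tm, tk, tn = tiles
    m, k, n = len(a), len(b), len(b[0])
    acc = [[0] * n for _ in range(m)]  # plays the role of "HLFHELinalg.zero"
    for i in range(0, m, tm):          # %arg2
        for p in range(0, k, tk):      # %arg4 (reduction dimension)
            for j in range(0, n, tn):  # %arg6
                a_t = extract_slice(a, i, p, tm, tk)
                b_t = extract_slice(b, p, j, tk, tn)
                c_t = extract_slice(acc, i, j, tm, tn)
                acc = insert_slice(add(c_t, matmul(a_t, b_t)), acc, i, j)
    return acc
```

With a 4x2 left operand and a 2x2 right operand, `tiled_matmul(a, b, (2, 2, 2))` gives the same result as `matmul(a, b)`; choosing a tile size of 1 on the reduction dimension exercises the accumulation through `add`.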
Only full tiles are supported, i.e., each dimension of the operands
must be a multiple of the corresponding tile size.
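The full-tile restriction amounts to a divisibility check on every matmul dimension. The Python helper below is a hypothetical sketch of that condition, not code from the pass, and again assumes that the three tile sizes apply to (M, K, N) for an M x K left operand and a K x N right operand.

```python
def check_full_tiles(lhs_shape, rhs_shape, tile_sizes):
    """Return True if each matmul dimension is a multiple of its tile size.

    Hypothetical helper; assumes tile_sizes = (tm, tk, tn) for an
    (M x K) lhs and a (K x N) rhs.
    """
    (m, k), (k_rhs, n) = lhs_shape, rhs_shape
    if k != k_rhs:
        raise ValueError("inner dimensions of the operands do not match")
    if len(tile_sizes) != 3 or any(t <= 0 for t in tile_sizes):
        raise ValueError("expected three positive tile sizes")
    # A partial tile would appear wherever a dimension is not divisible.
    return all(dim % t == 0 for dim, t in zip((m, k, n), tile_sizes))
```

With the shapes from the example (4x2 and 2x2) and tile sizes [2, 2, 2] the check passes, whereas a 3x2 left operand would be rejected because 3 is not a multiple of 2.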