# Mirror of https://github.com/zama-ai/concrete.git
# (synced 2026-02-10 04:35:03 -05:00; original file: 30 lines, 872 B, Python).
from mlir.dialects.linalg.opdsl.lang import *
# Element-type variables for the two inputs of the convolution defined below:
# T1 is the element type of the input tensor `I`, T2 that of the kernel `K`.
# Keeping them distinct lets the two operands have different element types;
# both are cast to the output type `U` inside the op body.
T1 = TV.T1
T2 = TV.T2

# Symbolic batch dimension. It is not referenced by the op defined in this
# file. NOTE(review): presumably kept for parity with upstream MLIR's
# core_named_ops.py or for external importers; verify before removing.
Batch = S.Batch


@linalg_structured_op
def fhelinalg_conv_2d_nchw_fchw(
    I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
    K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
    O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True),
    strides=AttributeDef(S.SH, S.SW),
    dilations=AttributeDef(S.DH, S.DW),
):
    """Performs 2-D convolution.

    Layout:
      * Input: NCHW.
      * Kernel: FCHW.
      * Output: NFHW.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.

    Operands and attributes (symbolic OpDSL definitions):
      I: input, element type T1, shape
         (N, C, OH * SH + KH * DH, OW * SW + KW * DW).
      K: kernel, element type T2, shape (F, C, KH, KW).
      O: output/accumulator, element type U, shape (N, F, OH, OW).
      strides: (SH, SW), multiplies the output spatial indices.
      dilations: (DH, DW), multiplies the kernel spatial indices.
    """
    implements(ConvolutionOpInterface)
    # n, f, oh, ow index the output; c, kh, kw appear only on the right-hand
    # side, so the `+=` below accumulates the products over them.
    domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
    # Input spatial position = output position * stride + kernel offset * dilation.
    # NOTE(review): `U` is not defined in this file; presumably it is the
    # predefined type variable exported by the opdsl star import. `cast` and
    # `AttributeDef` are the older opdsl spellings (newer MLIR uses
    # `TypeFn.cast_signed` / `IndexAttrDef`); keep in sync with the pinned LLVM.
    O[D.n, D.f, D.oh, D.ow] += cast(
        U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
    ) * cast(U, K[D.f, D.c, D.kh, D.kw])