diff --git a/compiler/include/zamalang-c/Dialect/HLFHELinalg.h b/compiler/include/zamalang-c/Dialect/HLFHELinalg.h new file mode 100644 index 000000000..7adebe0f1 --- /dev/null +++ b/compiler/include/zamalang-c/Dialect/HLFHELinalg.h @@ -0,0 +1,19 @@ +#ifndef ZAMALANG_C_DIALECT_HLFHELINALG_H +#define ZAMALANG_C_DIALECT_HLFHELINALG_H + +#include "mlir-c/IR.h" +#include "mlir-c/Registration.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/Support/LLVM.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(HLFHELinalg, hlfhelinalg); + +#ifdef __cplusplus +} +#endif + +#endif // ZAMALANG_C_DIALECT_HLFHELINALG_H diff --git a/compiler/lib/Bindings/Python/CMakeLists.txt b/compiler/lib/Bindings/Python/CMakeLists.txt index 6905142db..affe8b286 100644 --- a/compiler/lib/Bindings/Python/CMakeLists.txt +++ b/compiler/lib/Bindings/Python/CMakeLists.txt @@ -15,6 +15,7 @@ declare_mlir_python_extension(ZamalangBindingsPythonExtension.Core CompilerAPIModule.cpp EMBED_CAPI_LINK_LIBS ZAMALANGCAPIHLFHE + ZAMALANGCAPIHLFHELINALG ZAMALANGCAPISupport ) diff --git a/compiler/lib/Bindings/Python/ZamalangModule.cpp b/compiler/lib/Bindings/Python/ZamalangModule.cpp index 6ec2bed8f..27645122b 100644 --- a/compiler/lib/Bindings/Python/ZamalangModule.cpp +++ b/compiler/lib/Bindings/Python/ZamalangModule.cpp @@ -5,6 +5,7 @@ #include "mlir-c/Registration.h" #include "mlir/Bindings/Python/PybindAdaptors.h" #include "zamalang-c/Dialect/HLFHE.h" +#include "zamalang-c/Dialect/HLFHELinalg.h" #include "llvm-c/ErrorHandling.h" #include "llvm/Support/Signals.h" @@ -28,6 +29,9 @@ PYBIND11_MODULE(_zamalang, m) { MlirDialectHandle hlfhe = mlirGetDialectHandle__hlfhe__(); mlirDialectHandleRegisterDialect(hlfhe, context); mlirDialectHandleLoadDialect(hlfhe, context); + MlirDialectHandle hlfhelinalg = mlirGetDialectHandle__hlfhelinalg__(); + mlirDialectHandleRegisterDialect(hlfhelinalg, context); + mlirDialectHandleLoadDialect(hlfhelinalg, context); }, "Register Zamalang dialects on a PyMlirContext."); diff --git a/compiler/lib/CAPI/Dialect/CMakeLists.txt b/compiler/lib/CAPI/Dialect/CMakeLists.txt index 112e1c3c7..496981e3c 100644 --- a/compiler/lib/CAPI/Dialect/CMakeLists.txt +++ b/compiler/lib/CAPI/Dialect/CMakeLists.txt @@ -1,10 +1,2 @@ -set(LLVM_OPTIONAL_SOURCES HLFHE.cpp) - -add_mlir_public_c_api_library(ZAMALANGCAPIHLFHE - - HLFHE.cpp - - LINK_LIBS PUBLIC - MLIRCAPIIR - HLFHEDialect - ) +add_subdirectory(HLFHE) +add_subdirectory(HLFHELinalg) \ No newline at end of file diff --git a/compiler/lib/CAPI/Dialect/HLFHE/CMakeLists.txt b/compiler/lib/CAPI/Dialect/HLFHE/CMakeLists.txt new file mode 100644 index 000000000..112e1c3c7 --- /dev/null +++ b/compiler/lib/CAPI/Dialect/HLFHE/CMakeLists.txt @@ -0,0 +1,10 @@ +set(LLVM_OPTIONAL_SOURCES HLFHE.cpp) + +add_mlir_public_c_api_library(ZAMALANGCAPIHLFHE + + HLFHE.cpp + + LINK_LIBS PUBLIC + MLIRCAPIIR + HLFHEDialect + ) diff --git a/compiler/lib/CAPI/Dialect/HLFHE.cpp b/compiler/lib/CAPI/Dialect/HLFHE/HLFHE.cpp similarity index 100% rename from compiler/lib/CAPI/Dialect/HLFHE.cpp rename to compiler/lib/CAPI/Dialect/HLFHE/HLFHE.cpp diff --git a/compiler/lib/CAPI/Dialect/HLFHELinalg/CMakeLists.txt b/compiler/lib/CAPI/Dialect/HLFHELinalg/CMakeLists.txt new file mode 100644 index 000000000..1d41d3a56 --- /dev/null +++ b/compiler/lib/CAPI/Dialect/HLFHELinalg/CMakeLists.txt @@ -0,0 +1,10 @@ +set(LLVM_OPTIONAL_SOURCES HLFHELinalg.cpp) + +add_mlir_public_c_api_library(ZAMALANGCAPIHLFHELINALG + + HLFHELinalg.cpp + + LINK_LIBS PUBLIC + MLIRCAPIIR + HLFHELinalgDialect + ) diff --git a/compiler/lib/CAPI/Dialect/HLFHELinalg/HLFHELinalg.cpp b/compiler/lib/CAPI/Dialect/HLFHELinalg/HLFHELinalg.cpp new file mode 100644 index 000000000..6951f7933 --- /dev/null +++ b/compiler/lib/CAPI/Dialect/HLFHELinalg/HLFHELinalg.cpp @@ -0,0 +1,16 @@ +#include "zamalang-c/Dialect/HLFHELinalg.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/CAPI/Support.h" +#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h" +#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h" +#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.h" + +using namespace mlir::zamalang::HLFHELinalg; + +//===----------------------------------------------------------------------===// +// Dialect API. +//===----------------------------------------------------------------------===// + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(HLFHELinalg, hlfhelinalg, + HLFHELinalgDialect) diff --git a/compiler/tests/python/test_hlfhe_dialect.py b/compiler/tests/python/test_hlfhe_dialect.py index 311486c28..7c7d57c78 100644 --- a/compiler/tests/python/test_hlfhe_dialect.py +++ b/compiler/tests/python/test_hlfhe_dialect.py @@ -1,5 +1,5 @@ import pytest -from mlir.ir import Context +from mlir.ir import Context, RankedTensorType, Location from zamalang import register_dialects from zamalang.dialects import hlfhe @@ -12,6 +12,17 @@ def test_eint(width): assert eint.__str__() == f"!HLFHE.eint<{width}>" +@pytest.mark.parametrize("shape", [(1,), (2,), (1, 1), (1, 2), (2, 1), (3, 3, 3)]) +def test_eint_tensor(shape): + with Context() as ctx, Location.unknown(context=ctx): + register_dialects(ctx) + eint = hlfhe.EncryptedIntegerType.get(ctx, 3) + tensor = RankedTensorType.get(shape, eint) + assert ( + tensor.__str__() == f"tensor<{'x'.join(map(str, shape))}x!HLFHE.eint<{3}>>" + ) + + @pytest.mark.parametrize("width", [0, 8, 10, 12]) def test_invalid_eint(width): ctx = Context() diff --git a/compiler/tests/python/test_round_trip.py b/compiler/tests/python/test_round_trip.py index c17a05e3f..a4a6c68a5 100644 --- a/compiler/tests/python/test_round_trip.py +++ b/compiler/tests/python/test_round_trip.py @@ -44,6 +44,15 @@ VALID_INPUTS = [ """, id="add_eint_int_cst", ), + pytest.param( + """ + func @main(%a0: tensor<4x!HLFHE.eint<2>>, %a1: tensor<4xi3>) -> tensor<4x!HLFHE.eint<2>> { + %1 = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> tensor<4x!HLFHE.eint<2>> + return %1: tensor<4x!HLFHE.eint<2>> + } + """, + id="add_eint_int_1D", + ), ] INVALID_INPUTS = [ @@ -56,6 +65,16 @@ INVALID_INPUTS = [ """, id="eint<0>", ), + pytest.param( + """ + func @main(%a0: tensor<2x2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x2x2x4xi3>) -> tensor<2x2x3x4x!HLFHE.eint<2>> { + // expected-error @+1 {{'HLFHELinalg.add_eint_int' op has the dimension #2 of the operand #1 incompatible with other operands, got 2 expect 1 or 3}} + %1 = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<2x2x3x4x!HLFHE.eint<2>>, tensor<2x2x2x4xi3>) -> tensor<2x2x3x4x!HLFHE.eint<2>> + return %1 : tensor<2x2x3x4x!HLFHE.eint<2>> + } + """, + id="incompatible dimensions", + ), ]