diff --git a/compiler/tests/python/test_round_trip.py b/compiler/tests/python/test_round_trip.py index 25730c755..aab3d9f75 100644 --- a/compiler/tests/python/test_round_trip.py +++ b/compiler/tests/python/test_round_trip.py @@ -17,15 +17,13 @@ VALID_INPUTS = [ } """, """ - func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<2>>, - %arg1: memref<2xi3>, - %arg2: memref>) + func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, + %arg1: tensor<2xi3>) -> !HLFHE.eint<2> { - "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) : - (memref<2x!HLFHE.eint<2>>, memref<2xi3>, memref>) -> () - return + %1 = "HLFHE.dot_eint_int"(%arg0, %arg1) : + (tensor<2x!HLFHE.eint<2>>, tensor<2xi3>) -> !HLFHE.eint<2> + return %1 : !HLFHE.eint<2> } - """, ]