mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor: make BaseDataType __eq__ abstract
- update test files with dummy dtypes
This commit is contained in:
@@ -1,7 +1,11 @@
|
||||
"""File holding code to represent data types in a program."""
|
||||
|
||||
from abc import ABC
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseDataType(ABC):
|
||||
"""Base class to represent a data type."""
|
||||
|
||||
@abstractmethod
|
||||
def __eq__(self, o: object) -> bool:
|
||||
"""No default implementation."""
|
||||
|
||||
@@ -62,6 +62,9 @@ def test_value_is_encrypted_unsigned_integer(value: BaseValue, expected_result:
|
||||
class UnsupportedDataType(BaseDataType):
|
||||
"""Test helper class to represent an UnsupportedDataType"""
|
||||
|
||||
def __eq__(self, o: object) -> bool:
|
||||
return isinstance(o, self.__class__)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"dtype1,dtype2,expected_mixed_dtype",
|
||||
|
||||
@@ -5,7 +5,7 @@ from copy import deepcopy
|
||||
import pytest
|
||||
|
||||
from hdk.common import check_op_graph_is_integer_program, is_a_power_of_2
|
||||
from hdk.common.data_types.base import BaseDataType
|
||||
from hdk.common.data_types.floats import Float64
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import EncryptedValue
|
||||
from hdk.hnumpy.tracing import trace_numpy_function
|
||||
@@ -29,10 +29,6 @@ def test_is_a_power_of_2(x, result):
|
||||
assert is_a_power_of_2(x) == result
|
||||
|
||||
|
||||
class DummyNotInteger(BaseDataType):
|
||||
"""Dummy helper data type class"""
|
||||
|
||||
|
||||
def test_check_op_graph_is_integer_program():
|
||||
"""Test function for check_op_graph_is_integer_program"""
|
||||
|
||||
@@ -50,7 +46,7 @@ def test_check_op_graph_is_integer_program():
|
||||
assert len(offending_nodes) == 0
|
||||
|
||||
op_graph_copy = deepcopy(op_graph)
|
||||
op_graph_copy.output_nodes[0].outputs[0].data_type = DummyNotInteger()
|
||||
op_graph_copy.output_nodes[0].outputs[0].data_type = Float64
|
||||
|
||||
offending_nodes = []
|
||||
assert not check_op_graph_is_integer_program(op_graph_copy)
|
||||
@@ -59,7 +55,7 @@ def test_check_op_graph_is_integer_program():
|
||||
assert offending_nodes == [op_graph_copy.output_nodes[0]]
|
||||
|
||||
op_graph_copy = deepcopy(op_graph)
|
||||
op_graph_copy.input_nodes[0].inputs[0].data_type = DummyNotInteger()
|
||||
op_graph_copy.input_nodes[0].inputs[0].data_type = Float64
|
||||
|
||||
offending_nodes = []
|
||||
assert not check_op_graph_is_integer_program(op_graph_copy)
|
||||
@@ -68,8 +64,8 @@ def test_check_op_graph_is_integer_program():
|
||||
assert offending_nodes == [op_graph_copy.input_nodes[0]]
|
||||
|
||||
op_graph_copy = deepcopy(op_graph)
|
||||
op_graph_copy.input_nodes[0].inputs[0].data_type = DummyNotInteger()
|
||||
op_graph_copy.input_nodes[1].inputs[0].data_type = DummyNotInteger()
|
||||
op_graph_copy.input_nodes[0].inputs[0].data_type = Float64
|
||||
op_graph_copy.input_nodes[1].inputs[0].data_type = Float64
|
||||
|
||||
offending_nodes = []
|
||||
assert not check_op_graph_is_integer_program(op_graph_copy)
|
||||
|
||||
Reference in New Issue
Block a user