refactor: make BaseDataType __eq__ abstract

- update test files with dummy dtypes
This commit is contained in:
Arthur Meyre
2021-08-20 09:53:52 +02:00
parent 7a0f11b1b0
commit e41f77349f
3 changed files with 13 additions and 10 deletions

View File

@@ -1,7 +1,11 @@
"""File holding code to represent data types in a program."""
from abc import ABC
from abc import ABC, abstractmethod
class BaseDataType(ABC):
"""Base class to represent a data type."""
@abstractmethod
def __eq__(self, o: object) -> bool:
"""No default implementation."""

View File

@@ -62,6 +62,9 @@ def test_value_is_encrypted_unsigned_integer(value: BaseValue, expected_result:
class UnsupportedDataType(BaseDataType):
"""Test helper class to represent an UnsupportedDataType"""
def __eq__(self, o: object) -> bool:
return isinstance(o, self.__class__)
@pytest.mark.parametrize(
"dtype1,dtype2,expected_mixed_dtype",

View File

@@ -5,7 +5,7 @@ from copy import deepcopy
import pytest
from hdk.common import check_op_graph_is_integer_program, is_a_power_of_2
from hdk.common.data_types.base import BaseDataType
from hdk.common.data_types.floats import Float64
from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import EncryptedValue
from hdk.hnumpy.tracing import trace_numpy_function
@@ -29,10 +29,6 @@ def test_is_a_power_of_2(x, result):
assert is_a_power_of_2(x) == result
class DummyNotInteger(BaseDataType):
"""Dummy helper data type class"""
def test_check_op_graph_is_integer_program():
"""Test function for check_op_graph_is_integer_program"""
@@ -50,7 +46,7 @@ def test_check_op_graph_is_integer_program():
assert len(offending_nodes) == 0
op_graph_copy = deepcopy(op_graph)
op_graph_copy.output_nodes[0].outputs[0].data_type = DummyNotInteger()
op_graph_copy.output_nodes[0].outputs[0].data_type = Float64
offending_nodes = []
assert not check_op_graph_is_integer_program(op_graph_copy)
@@ -59,7 +55,7 @@ def test_check_op_graph_is_integer_program():
assert offending_nodes == [op_graph_copy.output_nodes[0]]
op_graph_copy = deepcopy(op_graph)
op_graph_copy.input_nodes[0].inputs[0].data_type = DummyNotInteger()
op_graph_copy.input_nodes[0].inputs[0].data_type = Float64
offending_nodes = []
assert not check_op_graph_is_integer_program(op_graph_copy)
@@ -68,8 +64,8 @@ def test_check_op_graph_is_integer_program():
assert offending_nodes == [op_graph_copy.input_nodes[0]]
op_graph_copy = deepcopy(op_graph)
op_graph_copy.input_nodes[0].inputs[0].data_type = DummyNotInteger()
op_graph_copy.input_nodes[1].inputs[0].data_type = DummyNotInteger()
op_graph_copy.input_nodes[0].inputs[0].data_type = Float64
op_graph_copy.input_nodes[1].inputs[0].data_type = Float64
offending_nodes = []
assert not check_op_graph_is_integer_program(op_graph_copy)