From e41f77349f70238d67d5852caf2510911d32db0c Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Fri, 20 Aug 2021 09:53:52 +0200 Subject: [PATCH] refactor: make BaseDataType __eq__ abstract - update test files with dummy dtypes --- hdk/common/data_types/base.py | 6 +++++- tests/common/data_types/test_dtypes_helpers.py | 3 +++ tests/common/test_common_helpers.py | 14 +++++--------- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/hdk/common/data_types/base.py b/hdk/common/data_types/base.py index 13ef63fe8..834e75dc9 100644 --- a/hdk/common/data_types/base.py +++ b/hdk/common/data_types/base.py @@ -1,7 +1,11 @@ """File holding code to represent data types in a program.""" -from abc import ABC +from abc import ABC, abstractmethod class BaseDataType(ABC): """Base class to represent a data type.""" + + @abstractmethod + def __eq__(self, o: object) -> bool: + """No default implementation.""" diff --git a/tests/common/data_types/test_dtypes_helpers.py b/tests/common/data_types/test_dtypes_helpers.py index 1a4761330..805424c99 100644 --- a/tests/common/data_types/test_dtypes_helpers.py +++ b/tests/common/data_types/test_dtypes_helpers.py @@ -62,6 +62,9 @@ def test_value_is_encrypted_unsigned_integer(value: BaseValue, expected_result: class UnsupportedDataType(BaseDataType): """Test helper class to represent an UnsupportedDataType""" + def __eq__(self, o: object) -> bool: + return isinstance(o, self.__class__) + @pytest.mark.parametrize( "dtype1,dtype2,expected_mixed_dtype", diff --git a/tests/common/test_common_helpers.py b/tests/common/test_common_helpers.py index c0c2aef0c..a0e076a45 100644 --- a/tests/common/test_common_helpers.py +++ b/tests/common/test_common_helpers.py @@ -5,7 +5,7 @@ from copy import deepcopy import pytest from hdk.common import check_op_graph_is_integer_program, is_a_power_of_2 -from hdk.common.data_types.base import BaseDataType +from hdk.common.data_types.floats import Float64 from hdk.common.data_types.integers import Integer from hdk.common.data_types.values import EncryptedValue from hdk.hnumpy.tracing import trace_numpy_function @@ -29,10 +29,6 @@ def test_is_a_power_of_2(x, result): assert is_a_power_of_2(x) == result -class DummyNotInteger(BaseDataType): - """Dummy helper data type class""" - - def test_check_op_graph_is_integer_program(): """Test function for check_op_graph_is_integer_program""" @@ -50,7 +46,7 @@ def test_check_op_graph_is_integer_program(): assert len(offending_nodes) == 0 op_graph_copy = deepcopy(op_graph) - op_graph_copy.output_nodes[0].outputs[0].data_type = DummyNotInteger() + op_graph_copy.output_nodes[0].outputs[0].data_type = Float64 offending_nodes = [] assert not check_op_graph_is_integer_program(op_graph_copy) @@ -59,7 +55,7 @@ def test_check_op_graph_is_integer_program(): assert offending_nodes == [op_graph_copy.output_nodes[0]] op_graph_copy = deepcopy(op_graph) - op_graph_copy.input_nodes[0].inputs[0].data_type = DummyNotInteger() + op_graph_copy.input_nodes[0].inputs[0].data_type = Float64 offending_nodes = [] assert not check_op_graph_is_integer_program(op_graph_copy) @@ -68,8 +64,8 @@ def test_check_op_graph_is_integer_program(): assert offending_nodes == [op_graph_copy.input_nodes[0]] op_graph_copy = deepcopy(op_graph) - op_graph_copy.input_nodes[0].inputs[0].data_type = DummyNotInteger() - op_graph_copy.input_nodes[1].inputs[0].data_type = DummyNotInteger() + op_graph_copy.input_nodes[0].inputs[0].data_type = Float64 + op_graph_copy.input_nodes[1].inputs[0].data_type = Float64 offending_nodes = [] assert not check_op_graph_is_integer_program(op_graph_copy)