diff --git a/hdk/common/data_types/floats.py b/hdk/common/data_types/floats.py new file mode 100644 index 000000000..7021886e0 --- /dev/null +++ b/hdk/common/data_types/floats.py @@ -0,0 +1,24 @@ +"""This file holds the definitions for floating point types""" + +from . import base + + +class Float(base.BaseDataType): + """Class representing a float""" + + # bit_width is the total number of bits used to represent a floating point number, including + # sign bit, exponent and mantissa + bit_width: int + + def __init__(self, bit_width: int) -> None: + self.bit_width = bit_width + + def __repr__(self) -> str: + return f"{self.__class__.__name__}<{self.bit_width} bits>" + + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) and self.bit_width == other.bit_width + + +Float32 = lambda: Float(32) +Float64 = lambda: Float(64) diff --git a/tests/common/data_types/test_floats.py b/tests/common/data_types/test_floats.py new file mode 100644 index 000000000..49a7ba380 --- /dev/null +++ b/tests/common/data_types/test_floats.py @@ -0,0 +1,50 @@ +"""Test file for float data types""" + + +import pytest + +from hdk.common.data_types.floats import Float, Float32, Float64 + + +@pytest.mark.parametrize( + "float_,expected_repr_str", + [ + pytest.param( + Float32(), + "Float<32 bits>", + id="Float32", + ), + pytest.param( + Float(32), + "Float<32 bits>", + id="32 bits Float", + ), + pytest.param( + Float64(), + "Float<64 bits>", + id="Float64", + ), + pytest.param( + Float(64), + "Float<64 bits>", + id="64 bits Float", + ), + ], +) +def test_floats_repr(float_: Float, expected_repr_str: str): + """Test float repr""" + assert float_.__repr__() == expected_repr_str + + +@pytest.mark.parametrize( + "float_1,float_2,expected_equal", + [ + pytest.param(Float32(), Float(32), True), + pytest.param(Float(64), Float32(), False), + pytest.param(Float64(), Float(64), True), + ], +) +def test_floats_eq(float_1: Float, float_2: Float, expected_equal: bool): + """Test float eq""" + assert expected_equal == (float_1 == float_2) + assert expected_equal == (float_2 == float_1)