mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
dev(floats): add Float class to represent a floating point value
This commit is contained in:
24
hdk/common/data_types/floats.py
Normal file
24
hdk/common/data_types/floats.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""This file holds the definitions for floating point types"""
|
||||
|
||||
from . import base
|
||||
|
||||
|
||||
class Float(base.BaseDataType):
|
||||
"""Class representing a float"""
|
||||
|
||||
# bit_width is the total number of bits used to represent a floating point number, including
|
||||
# sign bit, exponent and mantissa
|
||||
bit_width: int
|
||||
|
||||
def __init__(self, bit_width: int) -> None:
|
||||
self.bit_width = bit_width
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}<{self.bit_width} bits>"
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, self.__class__) and self.bit_width == other.bit_width
|
||||
|
||||
|
||||
Float32 = lambda: Float(32)
|
||||
Float64 = lambda: Float(64)
|
||||
50
tests/common/data_types/test_floats.py
Normal file
50
tests/common/data_types/test_floats.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Test file for float data types"""
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from hdk.common.data_types.floats import Float, Float32, Float64
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"float_,expected_repr_str",
|
||||
[
|
||||
pytest.param(
|
||||
Float32(),
|
||||
"Float<32 bits>",
|
||||
id="Float32",
|
||||
),
|
||||
pytest.param(
|
||||
Float(32),
|
||||
"Float<32 bits>",
|
||||
id="32 bits Float",
|
||||
),
|
||||
pytest.param(
|
||||
Float64(),
|
||||
"Float<64 bits>",
|
||||
id="Float64",
|
||||
),
|
||||
pytest.param(
|
||||
Float(64),
|
||||
"Float<64 bits>",
|
||||
id="64 bits Float",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_floats_repr(float_: Float, expected_repr_str: str):
|
||||
"""Test float repr"""
|
||||
assert float_.__repr__() == expected_repr_str
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"float_1,float_2,expected_equal",
|
||||
[
|
||||
pytest.param(Float32(), Float(32), True),
|
||||
pytest.param(Float(64), Float32(), False),
|
||||
pytest.param(Float64(), Float(64), True),
|
||||
],
|
||||
)
|
||||
def test_floats_eq(float_1: Float, float_2: Float, expected_equal: bool):
|
||||
"""Test float eq"""
|
||||
assert expected_equal == (float_1 == float_2)
|
||||
assert expected_equal == (float_2 == float_1)
|
||||
Reference in New Issue
Block a user