rm emu related change

This commit is contained in:
b1tg
2026-02-21 23:39:49 +00:00
parent 89538d86e0
commit efa4763c22
2 changed files with 2 additions and 30 deletions

View File

@@ -169,16 +169,6 @@ class TestDTypeALU(unittest.TestCase):
def test_fp8e5m2fnuz(self, a, b, op):
universal_test(from_storage_scalar(a, dtypes.fp8e5m2fnuz), from_storage_scalar(b, dtypes.fp8e5m2fnuz), dtypes.fp8e5m2fnuz, op)
@given(ht.fp8e4m3fnuz, ht.fp8e4m3fnuz, strat.sampled_from(binary_operations))
@Context(EMULATED_DTYPES="fp8e4m3fnuz")
def test_emulated_fp8e4m3fnuz(self, a, b, op):
universal_test(from_storage_scalar(a, dtypes.fp8e4m3fnuz), from_storage_scalar(b, dtypes.fp8e4m3fnuz), dtypes.fp8e4m3fnuz, op)
@given(ht.fp8e5m2fnuz, ht.fp8e5m2fnuz, strat.sampled_from(binary_operations))
@Context(EMULATED_DTYPES="fp8e5m2fnuz")
def test_emulated_fp8e5m2fnuz(self, a, b, op):
universal_test(from_storage_scalar(a, dtypes.fp8e5m2fnuz), from_storage_scalar(b, dtypes.fp8e5m2fnuz), dtypes.fp8e5m2fnuz, op)
@given(ht.float32, strat.sampled_from(unary_operations))
def test_float32_unary(self, a, op): universal_test_unary(a, dtypes.float32, op)
@@ -234,18 +224,6 @@ class TestDTypeALU(unittest.TestCase):
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e5m2fnuz) != 0.0)
universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e5m2fnuz), dtypes.fp8e5m2fnuz, op)
@given(ht.fp8e4m3fnuz, strat.sampled_from(unary_operations))
@Context(EMULATED_DTYPES="fp8e4m3fnuz")
def test_emulated_fp8e4m3fnuz_unary(self, a, op):
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e4m3fnuz) != 0.0)
universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e4m3fnuz), dtypes.fp8e4m3fnuz, op)
@given(ht.fp8e5m2fnuz, strat.sampled_from(unary_operations))
@Context(EMULATED_DTYPES="fp8e5m2fnuz")
def test_emulated_fp8e5m2fnuz_unary(self, a, op):
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e5m2fnuz) != 0.0)
universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e5m2fnuz), dtypes.fp8e5m2fnuz, op)
@given(ht.uint8, ht.uint8, strat.sampled_from(integer_binary_operations))
def test_uint8(self, a, b, op): universal_test(a, b, dtypes.uint8, op)

View File

@@ -14,7 +14,7 @@ def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp):
# *** helper functions for bit manipulation ***
def mantissa_bits(d:DType) -> int: return dtypes.finfo(d.scalar())[1]
def exponent_bias(d:DType) -> int: return (1 << (dtypes.finfo(d.scalar())[0] - 1)) - (0 if d.scalar() in dtypes.fp8_fnuz else 1)
def exponent_bias(d:DType) -> int: return (1 << (dtypes.finfo(d.scalar())[0] - 1)) - 1
def exponent_mask(d:DType) -> int: return (1 << dtypes.finfo(d.scalar())[0]) - 1
# **** utils ****
@@ -389,10 +389,6 @@ def f2f(v, fr:DType, to:DType):
sign, nosign = shl((v & shl(1, fs-1)).cast(f2f_dt[to]), ts - fs), (v & (shl(1, fs-1) - 1)).cast(f2f_dt[to])
exp, norm = shr(nosign, fm), shl(nosign, tm - fm) + shl(tb - fb, tm)
nan = shl(nosign, tm - fm) | shl((shl(1, te) - 1), tm)
if fr in dtypes.fp8_fnuz:
fnuz_nan = sign.ne(0) & nosign.eq(0)
qnan = shl(shl(1, te) - 1, tm) | shl(1, tm - 1)
return fnuz_nan.where(qnan, sign | exp.eq(0).where(0, norm)).bitcast(to)
# fp8e4m3 has only one nan
is_nan = (nosign.eq(shl(1, fm + fe) - 1) if fr == dtypes.fp8e4m3 else exp.eq(shl(1, fe) - 1))
return (sign | exp.eq(0).where(0, is_nan.where(nan, norm))).bitcast(to)
@@ -404,14 +400,12 @@ def f2f(v, fr:DType, to:DType):
nan_mantissa = (shl(1, tm) - 1) if to == dtypes.fp8e4m3 else (shr(nosign, fm - tm) & (shl(1, tm) - 1))
nan = (sign | nan_mantissa | shl(shl(1, te) - 1, tm)).cast(f2f_dt[to])
is_nan = (shr(v, fm) & (shl(1, fe) - 1)).eq(shl(1, fe) - 1)
if to in dtypes.fp8_fnuz: return is_nan.where(shl(1, ts - 1), underflow.where(0, sign.cast(f2f_dt[to]) | norm))
return is_nan.where(nan, sign.cast(f2f_dt[to]) | underflow.where(0, norm))
else: raise NotImplementedError(f"unsupported decomp {fr} -> {to}")
def f2f_clamp(val:UOp, dt:DType) -> UOp:
e, m = dtypes.finfo(dt)
if dt in dtypes.fp8_fnuz: max_exp, max_man = (1 << e) - 1, (1 << m) - 1
else: max_exp, max_man = ((1 << e) - 1, (1 << m) - 2) if dt == dtypes.fp8e4m3 else ((1 << e) - 2, (1 << m) - 1)
max_exp, max_man = ((1 << e) - 1, (1 << m) - 2) if dt == dtypes.fp8e4m3 else ((1 << e) - 2, (1 << m) - 1)
mx = val.const_like(2.0**(max_exp - exponent_bias(dt)) * (1.0 + max_man / (1 << m)))
sat = mx if dt in dtypes.fp8s else val.const_like(float('inf'))
# FIXME: CMPLT of nan is undefined