mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
rm emu related change
This commit is contained in:
@@ -169,16 +169,6 @@ class TestDTypeALU(unittest.TestCase):
|
||||
def test_fp8e5m2fnuz(self, a, b, op):
|
||||
universal_test(from_storage_scalar(a, dtypes.fp8e5m2fnuz), from_storage_scalar(b, dtypes.fp8e5m2fnuz), dtypes.fp8e5m2fnuz, op)
|
||||
|
||||
@given(ht.fp8e4m3fnuz, ht.fp8e4m3fnuz, strat.sampled_from(binary_operations))
|
||||
@Context(EMULATED_DTYPES="fp8e4m3fnuz")
|
||||
def test_emulated_fp8e4m3fnuz(self, a, b, op):
|
||||
universal_test(from_storage_scalar(a, dtypes.fp8e4m3fnuz), from_storage_scalar(b, dtypes.fp8e4m3fnuz), dtypes.fp8e4m3fnuz, op)
|
||||
|
||||
@given(ht.fp8e5m2fnuz, ht.fp8e5m2fnuz, strat.sampled_from(binary_operations))
|
||||
@Context(EMULATED_DTYPES="fp8e5m2fnuz")
|
||||
def test_emulated_fp8e5m2fnuz(self, a, b, op):
|
||||
universal_test(from_storage_scalar(a, dtypes.fp8e5m2fnuz), from_storage_scalar(b, dtypes.fp8e5m2fnuz), dtypes.fp8e5m2fnuz, op)
|
||||
|
||||
@given(ht.float32, strat.sampled_from(unary_operations))
|
||||
def test_float32_unary(self, a, op): universal_test_unary(a, dtypes.float32, op)
|
||||
|
||||
@@ -234,18 +224,6 @@ class TestDTypeALU(unittest.TestCase):
|
||||
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e5m2fnuz) != 0.0)
|
||||
universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e5m2fnuz), dtypes.fp8e5m2fnuz, op)
|
||||
|
||||
@given(ht.fp8e4m3fnuz, strat.sampled_from(unary_operations))
|
||||
@Context(EMULATED_DTYPES="fp8e4m3fnuz")
|
||||
def test_emulated_fp8e4m3fnuz_unary(self, a, op):
|
||||
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e4m3fnuz) != 0.0)
|
||||
universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e4m3fnuz), dtypes.fp8e4m3fnuz, op)
|
||||
|
||||
@given(ht.fp8e5m2fnuz, strat.sampled_from(unary_operations))
|
||||
@Context(EMULATED_DTYPES="fp8e5m2fnuz")
|
||||
def test_emulated_fp8e5m2fnuz_unary(self, a, op):
|
||||
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e5m2fnuz) != 0.0)
|
||||
universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e5m2fnuz), dtypes.fp8e5m2fnuz, op)
|
||||
|
||||
@given(ht.uint8, ht.uint8, strat.sampled_from(integer_binary_operations))
|
||||
def test_uint8(self, a, b, op): universal_test(a, b, dtypes.uint8, op)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp):
|
||||
|
||||
# *** helper functions for bit manipulation ***
|
||||
def mantissa_bits(d:DType) -> int: return dtypes.finfo(d.scalar())[1]
|
||||
def exponent_bias(d:DType) -> int: return (1 << (dtypes.finfo(d.scalar())[0] - 1)) - (0 if d.scalar() in dtypes.fp8_fnuz else 1)
|
||||
def exponent_bias(d:DType) -> int: return (1 << (dtypes.finfo(d.scalar())[0] - 1)) - 1
|
||||
def exponent_mask(d:DType) -> int: return (1 << dtypes.finfo(d.scalar())[0]) - 1
|
||||
|
||||
# **** utils ****
|
||||
@@ -389,10 +389,6 @@ def f2f(v, fr:DType, to:DType):
|
||||
sign, nosign = shl((v & shl(1, fs-1)).cast(f2f_dt[to]), ts - fs), (v & (shl(1, fs-1) - 1)).cast(f2f_dt[to])
|
||||
exp, norm = shr(nosign, fm), shl(nosign, tm - fm) + shl(tb - fb, tm)
|
||||
nan = shl(nosign, tm - fm) | shl((shl(1, te) - 1), tm)
|
||||
if fr in dtypes.fp8_fnuz:
|
||||
fnuz_nan = sign.ne(0) & nosign.eq(0)
|
||||
qnan = shl(shl(1, te) - 1, tm) | shl(1, tm - 1)
|
||||
return fnuz_nan.where(qnan, sign | exp.eq(0).where(0, norm)).bitcast(to)
|
||||
# fp8e4m3 has only one nan
|
||||
is_nan = (nosign.eq(shl(1, fm + fe) - 1) if fr == dtypes.fp8e4m3 else exp.eq(shl(1, fe) - 1))
|
||||
return (sign | exp.eq(0).where(0, is_nan.where(nan, norm))).bitcast(to)
|
||||
@@ -404,14 +400,12 @@ def f2f(v, fr:DType, to:DType):
|
||||
nan_mantissa = (shl(1, tm) - 1) if to == dtypes.fp8e4m3 else (shr(nosign, fm - tm) & (shl(1, tm) - 1))
|
||||
nan = (sign | nan_mantissa | shl(shl(1, te) - 1, tm)).cast(f2f_dt[to])
|
||||
is_nan = (shr(v, fm) & (shl(1, fe) - 1)).eq(shl(1, fe) - 1)
|
||||
if to in dtypes.fp8_fnuz: return is_nan.where(shl(1, ts - 1), underflow.where(0, sign.cast(f2f_dt[to]) | norm))
|
||||
return is_nan.where(nan, sign.cast(f2f_dt[to]) | underflow.where(0, norm))
|
||||
else: raise NotImplementedError(f"unsupported decomp {fr} -> {to}")
|
||||
|
||||
def f2f_clamp(val:UOp, dt:DType) -> UOp:
|
||||
e, m = dtypes.finfo(dt)
|
||||
if dt in dtypes.fp8_fnuz: max_exp, max_man = (1 << e) - 1, (1 << m) - 1
|
||||
else: max_exp, max_man = ((1 << e) - 1, (1 << m) - 2) if dt == dtypes.fp8e4m3 else ((1 << e) - 2, (1 << m) - 1)
|
||||
max_exp, max_man = ((1 << e) - 1, (1 << m) - 2) if dt == dtypes.fp8e4m3 else ((1 << e) - 2, (1 << m) - 1)
|
||||
mx = val.const_like(2.0**(max_exp - exponent_bias(dt)) * (1.0 + max_man / (1 << m)))
|
||||
sat = mx if dt in dtypes.fp8s else val.const_like(float('inf'))
|
||||
# FIXME: CMPLT of nan is undefined
|
||||
|
||||
Reference in New Issue
Block a user