Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 23:48:01 -05:00
Refactor AMD pm rules to remove handwritten bf16 bool alus (#7136)
* refactor pm rules
  - remove unused handwritten methods
  - refactor amd pm rules to fix bug with bool alu
* add bf16 bool alu tests
* add bf16 tests
* hotfix: make atol consistent
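The core of the change is that bool ALUs (e.g. < and ==) whose operands are bfloat16 are now rewritten to operate on float casts of those operands instead of relying on handwritten hip_bfloat16 operators. A minimal standalone sketch of why that cast is lossless, using a numpy-only bf16 emulation (to_bf16 is an illustrative helper, not tinygrad code): every bfloat16 value is exactly representable as a float32 (bf16 is the top 16 bits of an f32 bit pattern), so comparisons on the float32 casts give the same boolean results.

import numpy as np

def to_bf16(x):
  # illustrative bf16 emulation: keep only the top 16 bits of each float32 bit pattern
  return (np.asarray(x, np.float32).view(np.uint32) & np.uint32(0xFFFF0000)).view(np.float32)

a, b = to_bf16([1.5, -2.0, 3.25]), to_bf16([1.5, 0.5, 3.0])
# the values are bf16-representable, so the float32 comparison is exactly the bf16 comparison
print(a < b)   # [False  True False]
print(a == b)  # [ True False False]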
@@ -46,6 +46,7 @@ class ht:
   float64 = strat.floats(width=64, allow_subnormal=False)
   float32 = strat.floats(width=32, allow_subnormal=False)
   float16 = strat.floats(width=16, allow_subnormal=False)
+  bfloat16 = strat.floats(width=16, allow_subnormal=False)
   uint8 = strat.integers(0, 255)
   uint16 = strat.integers(0, 65535)
   uint32 = strat.integers(0, 2**32-1)
@@ -60,7 +61,8 @@ def universal_test(a, b, dtype, op):
   if not isinstance(op, tuple): op = (op, op)
   tensor_value = (op[0](Tensor([a], dtype=dtype), Tensor([b], dtype=dtype))).numpy()
   numpy_value = op[1](np.array([a]).astype(_to_np_dtype(dtype)), np.array([b]).astype(_to_np_dtype(dtype)))
-  if dtype in dtypes_float: np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-10)
+  if dtype is dtypes.bfloat16: np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2)
+  elif dtype in dtypes_float: np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-10)
   else: np.testing.assert_equal(tensor_value, numpy_value)

 def universal_test_unary(a, dtype, op):
@@ -71,7 +73,7 @@ def universal_test_unary(a, dtype, op):
   run_schedule(sched)
   tensor_value = out.numpy()
   numpy_value = op[1](np.array([a]).astype(_to_np_dtype(dtype)))
-  if dtype in dtypes_float:
+  if dtype in (*dtypes_float, dtypes.bfloat16):
     np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2)
   else: np.testing.assert_equal(tensor_value, numpy_value)
   if op[0] != Tensor.reciprocal: # reciprocal is not supported in most backends
@@ -105,6 +107,10 @@ class TestDTypeALU(unittest.TestCase):
   @given(ht.float16, ht.float16, strat.sampled_from(binary_operations))
   def test_float16(self, a, b, op): universal_test(a, b, dtypes.float16, op)

+  @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16, Device.DEFAULT), f"no bfloat16 on {Device.DEFAULT}")
+  @given(ht.bfloat16, ht.bfloat16, strat.sampled_from(binary_operations))
+  def test_bfloat16(self, a, b, op): universal_test(a, b, dtypes.bfloat16, op)
+
   @given(ht.float32, strat.sampled_from(unary_operations))
   def test_float32_unary(self, a, op): universal_test_unary(a, dtypes.float32, op)

@@ -112,6 +118,10 @@ class TestDTypeALU(unittest.TestCase):
   @given(ht.float16, strat.sampled_from(unary_operations))
   def test_float16_unary(self, a, op): universal_test_unary(a, dtypes.float16, op)

+  @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16, Device.DEFAULT), f"no bfloat16 on {Device.DEFAULT}")
+  @given(ht.bfloat16, strat.sampled_from(unary_operations))
+  def test_bfloat16_unary(self, a, op): universal_test_unary(a, dtypes.bfloat16, op)
+
   @given(ht.uint8, ht.uint8, strat.sampled_from(integer_binary_operations))
   def test_uint8(self, a, b, op): universal_test(a, b, dtypes.uint8, op)

@@ -390,8 +390,10 @@ class AMDRenderer(CStyleLanguage):
   extra_matcher = PatternMatcher([
     (UPat(UOps.ALU, arg=TernaryOps.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
       lambda b,x,y: UOp(UOps.ALU, arg=TernaryOps.WHERE, dtype=dtypes.float, src=(b,x.cast(dtypes.float),y.cast(dtypes.float))).cast(dtypes.bfloat16)),
-    *[(UPat(UOps.ALU, dtype=dtypes.bfloat16, name="x"),
-      lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)))]]) + extra_pm
+    (UPat(UOps.ALU, dtype=dtypes.bfloat16, name="x"),
+      lambda x: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)),
+    (UPat(UOps.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
+      lambda alu,x,y: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg))]) + extra_pm

   def render_vector_prefix(self, dtype:DType) -> str:
     vec, scal = self.render_dtype(dtype), self.render_dtype(dtype.scalar())
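In short, the bool-alu bug this hunk fixes: the old catch-all rule only matched ALUs whose own dtype is bfloat16, but a comparison produces a bool, so bool ALUs over bf16 sources were never rewritten to float, which is presumably why the handwritten hip_bfloat16 operators removed in the next hunk existed. A minimal illustration of that matching difference, using plain dataclasses rather than tinygrad's UPat/UOp API (Node, old_rule, new_rule are illustrative names only):

from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  op: str       # e.g. "ADD", "CMPLT"
  dtype: str    # dtype of the node's result
  src: tuple = ()

bf16_const = Node("CONST", "bfloat16")
add = Node("ADD", "bfloat16", (bf16_const, bf16_const))  # result dtype is bfloat16
cmp = Node("CMPLT", "bool", (bf16_const, bf16_const))    # result dtype is bool

# old-style rule: rewrite any ALU whose result dtype is bfloat16 -> the comparison slips through
old_rule = lambda n: n.dtype == "bfloat16"
# new-style rule: also rewrite bool ALUs whose sources are bfloat16 -> cast the sources to float
new_rule = lambda n: n.dtype == "bool" and len(n.src) > 0 and all(s.dtype == "bfloat16" for s in n.src)

print(old_rule(add), old_rule(cmp))  # True False
print(new_rule(cmp))                 # True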
@@ -415,8 +417,6 @@ struct hip_bfloat16 {
     return *reinterpret_cast<float*>(&uval);
   }
 };
-static inline __attribute__((device)) bool operator<(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) < ((float)b); }
-static inline __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) == ((float)b); }
 """)

     for dtype in dedup(uop.dtype for uop in uops if uop.dtype.count > 1): prefix.append(self.render_vector_prefix(dtype))