Refactor AMD pm rules to remove handwritten bf16 bool alus (#7136)

* refactor pm rules

- remove unused handwritten methods
- refactor amd pm rules to fix bug with bool alu

* add bf16 bool alu tests

* add bf16 tests

* hotfix: make atol consistent
Author: ignaciosica
Date: 2024-10-17 22:00:46 -03:00
Committed by: GitHub
Parent: 534597e753
Commit: 8bcdd7c97d
2 changed files with 16 additions and 6 deletions
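
For context on the diff below: a bfloat16 comparison lowers to an ALU whose own dtype is bool but whose sources are bfloat16, so the generic "cast bfloat16 ALUs to float" rule (which matches on the ALU's dtype) never fired on it, and the AMD backend leaned on handwritten hip_bfloat16 comparison operators instead. A minimal sketch of the case this PR targets, assuming a device with bfloat16 support (the values are illustrative only):

from tinygrad import Tensor, dtypes

# Hypothetical repro: the comparison below produces a bool-dtype ALU whose
# sources are bfloat16. With this change it is rewritten so the sources are
# cast to float before the compare, instead of relying on handwritten
# hip_bfloat16 comparison operators in the AMD renderer's prefix.
a = Tensor([1.5, 2.5], dtype=dtypes.bfloat16)
b = Tensor([2.0, 2.0], dtype=dtypes.bfloat16)
print((a < b).numpy())  # expected: [ True False]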


@@ -46,6 +46,7 @@ class ht:
   float64 = strat.floats(width=64, allow_subnormal=False)
   float32 = strat.floats(width=32, allow_subnormal=False)
   float16 = strat.floats(width=16, allow_subnormal=False)
+  bfloat16 = strat.floats(width=16, allow_subnormal=False)
   uint8 = strat.integers(0, 255)
   uint16 = strat.integers(0, 65535)
   uint32 = strat.integers(0, 2**32-1)
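
An aside on the new ht.bfloat16 strategy: hypothesis's floats() only supports widths 16, 32 and 64, so bfloat16 inputs are drawn from the IEEE float16 strategy; float16's exponent range fits inside bfloat16's, but its 10-bit mantissa is rounded to 7 bits on cast, which is one reason the bfloat16 branches below use loose tolerances. A minimal sketch, assuming a recent hypothesis:

from hypothesis import strategies as strat

# width=16 means IEEE float16 values; hypothesis has no native bfloat16 width,
# so these values serve as bfloat16 test inputs and get rounded on cast.
bfloat16_inputs = strat.floats(width=16, allow_subnormal=False)
print(bfloat16_inputs.example())  # e.g. 0.5 (values vary run to run)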
@@ -60,7 +61,8 @@ def universal_test(a, b, dtype, op):
   if not isinstance(op, tuple): op = (op, op)
   tensor_value = (op[0](Tensor([a], dtype=dtype), Tensor([b], dtype=dtype))).numpy()
   numpy_value = op[1](np.array([a]).astype(_to_np_dtype(dtype)), np.array([b]).astype(_to_np_dtype(dtype)))
-  if dtype in dtypes_float: np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-10)
+  if dtype is dtypes.bfloat16: np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2)
+  elif dtype in dtypes_float: np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-10)
   else: np.testing.assert_equal(tensor_value, numpy_value)
 def universal_test_unary(a, dtype, op):
@@ -71,7 +73,7 @@ def universal_test_unary(a, dtype, op):
   run_schedule(sched)
   tensor_value = out.numpy()
   numpy_value = op[1](np.array([a]).astype(_to_np_dtype(dtype)))
-  if dtype in dtypes_float:
+  if dtype in (*dtypes_float, dtypes.bfloat16):
     np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2)
   else: np.testing.assert_equal(tensor_value, numpy_value)
   if op[0] != Tensor.reciprocal: # reciprocal is not supported in most backends
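
On the atol=1e-3 / rtol=1e-2 choice: bfloat16 keeps only 7 explicit mantissa bits, so a single rounding step already costs a relative error of up to about 2^-8 ≈ 0.4%. A quick numpy sketch of that error (the truncate_to_bf16 helper is illustrative, not tinygrad code):

import numpy as np

def truncate_to_bf16(x):
  # Drop the low 16 bits of a float32: a crude round-toward-zero bfloat16,
  # enough to see the size of the representation error.
  bits = np.asarray(x, dtype=np.float32).view(np.uint32) & np.uint32(0xFFFF0000)
  return bits.view(np.float32)

x = np.float32(3.1415927)
bf = truncate_to_bf16(x)
print(float(bf), abs(float(bf) - float(x)) / float(x))  # ~3.140625, rel err ~3e-4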
@@ -105,6 +107,10 @@ class TestDTypeALU(unittest.TestCase):
   @given(ht.float16, ht.float16, strat.sampled_from(binary_operations))
   def test_float16(self, a, b, op): universal_test(a, b, dtypes.float16, op)
+  @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16, Device.DEFAULT), f"no bfloat16 on {Device.DEFAULT}")
+  @given(ht.bfloat16, ht.bfloat16, strat.sampled_from(binary_operations))
+  def test_bfloat16(self, a, b, op): universal_test(a, b, dtypes.bfloat16, op)
   @given(ht.float32, strat.sampled_from(unary_operations))
   def test_float32_unary(self, a, op): universal_test_unary(a, dtypes.float32, op)
@@ -112,6 +118,10 @@ class TestDTypeALU(unittest.TestCase):
   @given(ht.float16, strat.sampled_from(unary_operations))
   def test_float16_unary(self, a, op): universal_test_unary(a, dtypes.float16, op)
+  @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16, Device.DEFAULT), f"no bfloat16 on {Device.DEFAULT}")
+  @given(ht.bfloat16, strat.sampled_from(unary_operations))
+  def test_bfloat16_unary(self, a, op): universal_test_unary(a, dtypes.bfloat16, op)
   @given(ht.uint8, ht.uint8, strat.sampled_from(integer_binary_operations))
   def test_uint8(self, a, b, op): universal_test(a, b, dtypes.uint8, op)


@@ -390,8 +390,10 @@ class AMDRenderer(CStyleLanguage):
   extra_matcher = PatternMatcher([
     (UPat(UOps.ALU, arg=TernaryOps.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
       lambda b,x,y: UOp(UOps.ALU, arg=TernaryOps.WHERE, dtype=dtypes.float, src=(b,x.cast(dtypes.float),y.cast(dtypes.float))).cast(dtypes.bfloat16)),
-    *[(UPat(UOps.ALU, dtype=dtypes.bfloat16, name="x"),
-      lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)))]]) + extra_pm
+    (UPat(UOps.ALU, dtype=dtypes.bfloat16, name="x"),
+      lambda x: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)),
+    (UPat(UOps.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
+      lambda alu,x,y: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg))]) + extra_pm
   def render_vector_prefix(self, dtype:DType) -> str:
     vec, scal = self.render_dtype(dtype), self.render_dtype(dtype.scalar())
@@ -415,8 +417,6 @@ struct hip_bfloat16 {
 return *reinterpret_cast<float*>(&uval);
 }
 };
-static inline __attribute__((device)) bool operator<(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) < ((float)b); }
-static inline __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) == ((float)b); }
 """)
 for dtype in dedup(uop.dtype for uop in uops if uop.dtype.count > 1): prefix.append(self.render_vector_prefix(dtype))
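
To exercise the new bool-ALU rule in isolation, a rough sketch; the import paths, op names, and the UOp.const / PatternMatcher.rewrite calls are assumptions against this revision of tinygrad and may need adjusting:

from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, UOps, BinaryOps
from tinygrad.renderer.cstyle import AMDRenderer

# A bool-dtype ALU whose sources are bfloat16: exactly the shape the new rule matches.
x = UOp.const(dtypes.bfloat16, 1.0)
y = UOp.const(dtypes.bfloat16, 2.0)
cmp = UOp(UOps.ALU, dtypes.bool, (x, y), BinaryOps.CMPLT)

# The extra_matcher should rewrite this so both sources are cast to float first,
# which is what the removed hip_bfloat16 operator< / operator== used to handle.
print(AMDRenderer.extra_matcher.rewrite(cmp))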