rewrite MUL CAST SUM to CAST MULACC

This commit is contained in:
chenyu
2024-01-04 09:42:22 -08:00
parent ab7dfd637b
commit 91665ef143
3 changed files with 9 additions and 3 deletions

View File

@@ -73,7 +73,8 @@ class TestLinearizer(unittest.TestCase):
assert num_ops <= 0, "more load or alu uops than needed"
def test_sum_acc_dtype(self):
for tensor_dtype, acc_dtype in ((dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)):
for tensor_dtype, acc_dtype in (
(dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)):
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
k = Linearizer(a.lazydata.schedule()[-1].ast)
k.linearize()