remove type check for LazyOp.src now it's always LazyOp (#2963)

* remove type check for LazyOp.src now it's always LazyOp

also matched MULACC criteria between interpreted and compiled (that probably need to be refactored somewhere else)

* disable that test
This commit is contained in:
chenyu
2024-01-01 17:27:29 -05:00
committed by GitHub
parent c81ce9643d
commit fadaa2ec28
2 changed files with 6 additions and 6 deletions

View File

@@ -542,10 +542,9 @@ class Linearizer(Kernel):
assert offs is None, "not available if we aren't doing reduce"
return acc
# MULACC fusion. TODO: this is copied from Interpreted
if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == BinaryOps.MUL:
x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg)
if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == UnaryOps.CAST and x.src[0].src[0].__class__ is LazyOp and x.src[0].src[0].op == BinaryOps.MUL: # noqa: E501
x = LazyOp(TernaryOps.MULACC, x.src[0].src[0].src, x.arg)
if x.op == ReduceOps.SUM:
if x.src[0].op == BinaryOps.MUL: x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg)
if x.src[0].op == UnaryOps.CAST and x.src[0].src[0].op == BinaryOps.MUL: x = LazyOp(TernaryOps.MULACC, x.src[0].src[0].src, x.arg)
values = [self.ast_parse(v, acc, offs, loaded_buffers, loop_ctx=loop_ctx, cache=cache) for v in x.src]
ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
if x.op in ops:

View File

@@ -213,8 +213,9 @@ def _get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> Interpret
def _interpret_ast(ast:LazyOp) -> str:
# TODO: shortcutted store won't work with strides
if ast.op == BufferOps.STORE: return _interpret_ast(ast.src[0])
if TernaryOps.MULACC in fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
if TernaryOps.MULACC in fxn_for_op and ast.op == ReduceOps.SUM:
if ast.src[0].op == BinaryOps.MUL: ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
if ast.src[0].op == UnaryOps.CAST and ast.src[0].src[0].op == BinaryOps.MUL: ast = LazyOp(TernaryOps.MULACC, ast.src[0].src[0].src, ast.arg)
if ast.op in BufferOps:
if ast.op == ast.op == BufferOps.CONST: tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})"