diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 5932109b21..d3daa70c40 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -542,10 +542,9 @@ class Linearizer(Kernel): assert offs is None, "not available if we aren't doing reduce" return acc # MULACC fusion. TODO: this is copied from Interpreted - if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == BinaryOps.MUL: - x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg) - if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == UnaryOps.CAST and x.src[0].src[0].__class__ is LazyOp and x.src[0].src[0].op == BinaryOps.MUL: # noqa: E501 - x = LazyOp(TernaryOps.MULACC, x.src[0].src[0].src, x.arg) + if x.op == ReduceOps.SUM: + if x.src[0].op == BinaryOps.MUL: x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg) + if x.src[0].op == UnaryOps.CAST and x.src[0].src[0].op == BinaryOps.MUL: x = LazyOp(TernaryOps.MULACC, x.src[0].src[0].src, x.arg) values = [self.ast_parse(v, acc, offs, loaded_buffers, loop_ctx=loop_ctx, cache=cache) for v in x.src] ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC} if x.op in ops: diff --git a/tinygrad/device.py b/tinygrad/device.py index 0b87db53bf..d9f612ec54 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -213,8 +213,9 @@ def _get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> Interpret def _interpret_ast(ast:LazyOp) -> str: # TODO: shortcutted store won't work with strides if ast.op == BufferOps.STORE: return _interpret_ast(ast.src[0]) - if TernaryOps.MULACC in fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL: - ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg) + if TernaryOps.MULACC in fxn_for_op and ast.op == ReduceOps.SUM: + if ast.src[0].op == BinaryOps.MUL: ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg) + if ast.src[0].op == UnaryOps.CAST and ast.src[0].src[0].op == BinaryOps.MUL: ast = LazyOp(TernaryOps.MULACC, ast.src[0].src[0].src, ast.arg) if ast.op in BufferOps: if ast.op == ast.op == BufferOps.CONST: tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})"