mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
remove type check for LazyOp.src now it's always LazyOp (#2963)
* remove type check for LazyOp.src now it's always LazyOp also matched MULACC criteria between interpreted and compiled (that probably need to be refactored somewhere else) * disable that test
This commit is contained in:
@@ -542,10 +542,9 @@ class Linearizer(Kernel):
|
||||
assert offs is None, "not available if we aren't doing reduce"
|
||||
return acc
|
||||
# MULACC fusion. TODO: this is copied from Interpreted
|
||||
if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == BinaryOps.MUL:
|
||||
x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg)
|
||||
if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == UnaryOps.CAST and x.src[0].src[0].__class__ is LazyOp and x.src[0].src[0].op == BinaryOps.MUL: # noqa: E501
|
||||
x = LazyOp(TernaryOps.MULACC, x.src[0].src[0].src, x.arg)
|
||||
if x.op == ReduceOps.SUM:
|
||||
if x.src[0].op == BinaryOps.MUL: x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg)
|
||||
if x.src[0].op == UnaryOps.CAST and x.src[0].src[0].op == BinaryOps.MUL: x = LazyOp(TernaryOps.MULACC, x.src[0].src[0].src, x.arg)
|
||||
values = [self.ast_parse(v, acc, offs, loaded_buffers, loop_ctx=loop_ctx, cache=cache) for v in x.src]
|
||||
ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
|
||||
if x.op in ops:
|
||||
|
||||
@@ -213,8 +213,9 @@ def _get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> Interpret
|
||||
def _interpret_ast(ast:LazyOp) -> str:
|
||||
# TODO: shortcutted store won't work with strides
|
||||
if ast.op == BufferOps.STORE: return _interpret_ast(ast.src[0])
|
||||
if TernaryOps.MULACC in fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
|
||||
ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
|
||||
if TernaryOps.MULACC in fxn_for_op and ast.op == ReduceOps.SUM:
|
||||
if ast.src[0].op == BinaryOps.MUL: ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
|
||||
if ast.src[0].op == UnaryOps.CAST and ast.src[0].src[0].op == BinaryOps.MUL: ast = LazyOp(TernaryOps.MULACC, ast.src[0].src[0].src, ast.arg)
|
||||
|
||||
if ast.op in BufferOps:
|
||||
if ast.op == ast.op == BufferOps.CONST: tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})"
|
||||
|
||||
Reference in New Issue
Block a user