Flip MULACC operand order (accumulator last: a*b+c) to save a line (#997)

This commit is contained in:
George Hotz
2023-06-17 16:47:55 -07:00
committed by GitHub
parent d2b837c1d9
commit c690eeaca9
4 changed files with 5 additions and 6 deletions

View File

@@ -1,6 +1,6 @@
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict
from tinygrad.codegen.linearizer import Linearizer, UOps, Token
from tinygrad.ops import ASTRunner, FusedOps, BinaryOps, UnaryOps
from tinygrad.ops import ASTRunner, BinaryOps, UnaryOps
from tinygrad.helpers import DType, dtypes, DEBUG
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
import functools
@@ -138,7 +138,6 @@ class AssemblyCodegen(Linearizer):
for i,sr in enumerate(out.subregs()):
ins.append(AssemblyInstruction(UOps.ALU, sr, [tor[vin[i]]], UnaryOps.NOOP))
elif uop == UOps.ALU and newvar is not None:
if args == FusedOps.MULACC: vin = [vin[1], vin[2], vin[0]] # TODO: reorder MULACC everywhere
out = newreg(newvar) if newvar not in tor else tor[newvar]
# this is the only thing that can violate SSA
if args in [BinaryOps.CMPEQ, BinaryOps.CMPLT]:

View File

@@ -53,7 +53,7 @@ code_for_op: Final[Dict[Op, Callable]] = {
BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
BinaryOps.POW: lambda a,b: f"pow({a},{b})", BinaryOps.MAX: lambda a,b: f"max({a},{b})",
BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({b}*{c})+{a})"
BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})"
}
def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:

View File

@@ -60,7 +60,7 @@ def to_float4(x:List[Token]) -> Optional[Token]:
def get_grouped_maybe_float4(*values:List[Token], grouping_allowed=True):
assert all_same([len(x) for x in values]), f"all values are not the same length {values}"
# these use accumulators, we can only fold if the acc is a float4
idxs = get_grouped_float4_idxs(values[0]) if grouping_allowed else None
idxs = get_grouped_float4_idxs(values[-1]) if grouping_allowed else None
if idxs is not None:
new_idxs = []
new_values = []
@@ -341,7 +341,7 @@ class Linearizer:
values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
# TODO: fold float4 into a single uop when possible.
if isinstance(x.op, (ReduceOps, FusedOps)):
ret = [(idx, self.uop(UOps.ALU, val[0], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(acc, *values, grouping_allowed=self.supports_float4_alu)]
ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.supports_float4_alu)]
else:
ret = [(idx, self.uop(UOps.ALU, ssa('alu', dtypes._float4) if any(x.dtype == dtypes._float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.supports_float4_alu and x.op!=BinaryOps.CMPEQ)]
ordered_ret: List[Optional[Token]] = [None]*len(values[0])

View File

@@ -29,7 +29,7 @@ code_for_op: Final[Dict[Op, Callable]] = {
BinaryOps.POW: lambda builder,x,y: builder.call(builder._block.module.declare_intrinsic('llvm.pow', [ir.FloatType()]), [x,y], fastmath=('fast',)),
BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y, flags=('fast',)), ir.FloatType()),
BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=('fast',)), x, y, flags=('fast',)),
FusedOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(y,z, flags=('fast',)), x, flags=('fast',)),
FusedOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(x,y, flags=('fast',)), z, flags=('fast',)),
}
def uops_to_llvm_ir(uops:List[UOp], bufs:List[LazyBuffer]) -> str: