mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-18 10:31:41 -05:00
flip mulacc to save a line (#997)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict
|
||||
from tinygrad.codegen.linearizer import Linearizer, UOps, Token
|
||||
from tinygrad.ops import ASTRunner, FusedOps, BinaryOps, UnaryOps
|
||||
from tinygrad.ops import ASTRunner, BinaryOps, UnaryOps
|
||||
from tinygrad.helpers import DType, dtypes, DEBUG
|
||||
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
|
||||
import functools
|
||||
@@ -138,7 +138,6 @@ class AssemblyCodegen(Linearizer):
|
||||
for i,sr in enumerate(out.subregs()):
|
||||
ins.append(AssemblyInstruction(UOps.ALU, sr, [tor[vin[i]]], UnaryOps.NOOP))
|
||||
elif uop == UOps.ALU and newvar is not None:
|
||||
if args == FusedOps.MULACC: vin = [vin[1], vin[2], vin[0]] # TODO: reorder MULACC everywhere
|
||||
out = newreg(newvar) if newvar not in tor else tor[newvar]
|
||||
# this is the only thing that can violate SSA
|
||||
if args in [BinaryOps.CMPEQ, BinaryOps.CMPLT]:
|
||||
|
||||
@@ -53,7 +53,7 @@ code_for_op: Final[Dict[Op, Callable]] = {
|
||||
BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
|
||||
BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
|
||||
BinaryOps.POW: lambda a,b: f"pow({a},{b})", BinaryOps.MAX: lambda a,b: f"max({a},{b})",
|
||||
BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({b}*{c})+{a})"
|
||||
BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})"
|
||||
}
|
||||
|
||||
def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
|
||||
|
||||
@@ -60,7 +60,7 @@ def to_float4(x:List[Token]) -> Optional[Token]:
|
||||
def get_grouped_maybe_float4(*values:List[Token], grouping_allowed=True):
|
||||
assert all_same([len(x) for x in values]), f"all values are not the same length {values}"
|
||||
# these use accumulators, we can only fold if the acc is a float4
|
||||
idxs = get_grouped_float4_idxs(values[0]) if grouping_allowed else None
|
||||
idxs = get_grouped_float4_idxs(values[-1]) if grouping_allowed else None
|
||||
if idxs is not None:
|
||||
new_idxs = []
|
||||
new_values = []
|
||||
@@ -341,7 +341,7 @@ class Linearizer:
|
||||
values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
|
||||
# TODO: fold float4 into a single uop when possible.
|
||||
if isinstance(x.op, (ReduceOps, FusedOps)):
|
||||
ret = [(idx, self.uop(UOps.ALU, val[0], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(acc, *values, grouping_allowed=self.supports_float4_alu)]
|
||||
ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.supports_float4_alu)]
|
||||
else:
|
||||
ret = [(idx, self.uop(UOps.ALU, ssa('alu', dtypes._float4) if any(x.dtype == dtypes._float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.supports_float4_alu and x.op!=BinaryOps.CMPEQ)]
|
||||
ordered_ret: List[Optional[Token]] = [None]*len(values[0])
|
||||
|
||||
@@ -29,7 +29,7 @@ code_for_op: Final[Dict[Op, Callable]] = {
|
||||
BinaryOps.POW: lambda builder,x,y: builder.call(builder._block.module.declare_intrinsic('llvm.pow', [ir.FloatType()]), [x,y], fastmath=('fast',)),
|
||||
BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y, flags=('fast',)), ir.FloatType()),
|
||||
BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=('fast',)), x, y, flags=('fast',)),
|
||||
FusedOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(y,z, flags=('fast',)), x, flags=('fast',)),
|
||||
FusedOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(x,y, flags=('fast',)), z, flags=('fast',)),
|
||||
}
|
||||
|
||||
def uops_to_llvm_ir(uops:List[UOp], bufs:List[LazyBuffer]) -> str:
|
||||
|
||||
Reference in New Issue
Block a user