diff --git a/tinygrad/codegen/assembly.py b/tinygrad/codegen/assembly.py
index a4f1c67240..0cff2d16ff 100644
--- a/tinygrad/codegen/assembly.py
+++ b/tinygrad/codegen/assembly.py
@@ -1,6 +1,6 @@
 from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict
 from tinygrad.codegen.linearizer import Linearizer, UOps, Token
-from tinygrad.ops import ASTRunner, FusedOps, BinaryOps, UnaryOps
+from tinygrad.ops import ASTRunner, BinaryOps, UnaryOps
 from tinygrad.helpers import DType, dtypes, DEBUG
 from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
 import functools
@@ -138,7 +138,6 @@ class AssemblyCodegen(Linearizer):
         for i,sr in enumerate(out.subregs()):
           ins.append(AssemblyInstruction(UOps.ALU, sr, [tor[vin[i]]], UnaryOps.NOOP))
       elif uop == UOps.ALU and newvar is not None:
-        if args == FusedOps.MULACC: vin = [vin[1], vin[2], vin[0]] # TODO: reorder MULACC everywhere
         out = newreg(newvar) if newvar not in tor else tor[newvar]
         # this is the only thing that can violate SSA
         if args in [BinaryOps.CMPEQ, BinaryOps.CMPLT]:
diff --git a/tinygrad/codegen/cstyle.py b/tinygrad/codegen/cstyle.py
index b514b90fb1..6b1042adc9 100644
--- a/tinygrad/codegen/cstyle.py
+++ b/tinygrad/codegen/cstyle.py
@@ -53,7 +53,7 @@ code_for_op: Final[Dict[Op, Callable]] = {
   BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
   BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
   BinaryOps.POW: lambda a,b: f"pow({a},{b})", BinaryOps.MAX: lambda a,b: f"max({a},{b})",
-  BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({b}*{c})+{a})"
+  BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})"
 }
 
 def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index ed52fa9f24..85c23d002f 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -60,7 +60,7 @@ def to_float4(x:List[Token]) -> Optional[Token]:
 def get_grouped_maybe_float4(*values:List[Token], grouping_allowed=True):
   assert all_same([len(x) for x in values]), f"all values are not the same length {values}"
   # these use accumulators, we can only fold if the acc is a float4
-  idxs = get_grouped_float4_idxs(values[0]) if grouping_allowed else None
+  idxs = get_grouped_float4_idxs(values[-1]) if grouping_allowed else None
   if idxs is not None:
     new_idxs = []
     new_values = []
@@ -341,7 +341,7 @@ class Linearizer:
       values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
       # TODO: fold float4 into a single uop when possible.
       if isinstance(x.op, (ReduceOps, FusedOps)):
-        ret = [(idx, self.uop(UOps.ALU, val[0], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(acc, *values, grouping_allowed=self.supports_float4_alu)]
+        ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.supports_float4_alu)]
       else:
         ret = [(idx, self.uop(UOps.ALU, ssa('alu', dtypes._float4) if any(x.dtype == dtypes._float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.supports_float4_alu and x.op!=BinaryOps.CMPEQ)]
       ordered_ret: List[Optional[Token]] = [None]*len(values[0])
diff --git a/tinygrad/codegen/llvmir.py b/tinygrad/codegen/llvmir.py
index bb47d78715..5688192caf 100644
--- a/tinygrad/codegen/llvmir.py
+++ b/tinygrad/codegen/llvmir.py
@@ -29,7 +29,7 @@ code_for_op: Final[Dict[Op, Callable]] = {
   BinaryOps.POW: lambda builder,x,y: builder.call(builder._block.module.declare_intrinsic('llvm.pow', [ir.FloatType()]), [x,y], fastmath=('fast',)),
   BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y, flags=('fast',)), ir.FloatType()),
   BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=('fast',)), x, y, flags=('fast',)),
-  FusedOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(y,z, flags=('fast',)), x, flags=('fast',)),
+  FusedOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(x,y, flags=('fast',)), z, flags=('fast',)),
 }
 
 def uops_to_llvm_ir(uops:List[UOp], bufs:List[LazyBuffer]) -> str:
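
Reviewer note (not part of the patch): this change standardizes the MULACC operand order across backends. Previously the accumulator came first: cstyle rendered MULACC(a,b,c) as (b*c)+a, llvmir computed fmul(y,z)+x, and assembly.py had to rotate vin with [vin[1], vin[2], vin[0]] (the removed "TODO: reorder MULACC everywhere"). After the patch every backend computes multiply-then-accumulate with the accumulator last, and the linearizer passes acc after *values, which is why values[0]/val[0] become values[-1]/val[-1]. A minimal sketch of the new convention, using a hypothetical standalone mulacc helper that is not part of tinygrad:

  # hypothetical helper illustrating the argument order this patch adopts:
  # the accumulator is the LAST argument, mirroring cstyle's f"(({a}*{b})+{c})"
  # and llvmir's fadd(fmul(x, y), z)
  def mulacc(a: float, b: float, acc: float) -> float:
    return a * b + acc

  # one step of a dot-product reduction: acc = x[i]*w[i] + acc
  assert mulacc(2.0, 3.0, 10.0) == 16.0

With the accumulator last, val[-1] in Linearizer.ast_parse and values[-1] in get_grouped_maybe_float4 still refer to the accumulator token, so the float4 grouping check keys off the same operand as before the reorder.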