diff --git a/docs/abstractions.py b/docs/abstractions.py
index e7483eaf16..eb66b7f8c9 100644
--- a/docs/abstractions.py
+++ b/docs/abstractions.py
@@ -102,7 +102,7 @@ class LazyOp:
 class UnaryOps(Enum): EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto()
 class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); CMPLT = auto(); MAX = auto()
 class ReduceOps(Enum): SUM = auto(); MAX = auto()
-class TernaryOps(Enum): MULACC = auto(); WHERE = auto()
+class TernaryOps(Enum): WHERE = auto()
 class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto()
 Op = Union[UnaryOps, BinaryOps, ReduceOps, TernaryOps, LoadOps]
diff --git a/docs/adding_new_accelerators.md b/docs/adding_new_accelerators.md
index 79c4b32883..1c5354d926 100644
--- a/docs/adding_new_accelerators.md
+++ b/docs/adding_new_accelerators.md
@@ -12,7 +12,6 @@ reduce_op (SUM, MAX) # A -> B (smaller s
 binary_op (ADD, SUB, MUL, DIV, CMPEQ, MAX) # A + A -> A (all the same size)
 load_op (EMPTY, CONST, FROM, CONTIGUOUS, CUSTOM) # -> A (initialize data on device)
 ternary_op (WHERE) # A, A, A -> A
-ternary_op [[optional]] (MULACC) # A * A -> B
 ```

 ## mlops
diff --git a/test/test_uops.py b/test/test_uops.py
index 31bce05bb9..48703932a3 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -87,7 +87,6 @@ class TestFloatUOps(TestUOps):
   def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: a<b)
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ ... @@
+          # [...] MUL does not fold, this maintains original MULACC code path
+          if all(x.uop is not UOps.LOOP for x in get_recursive_parents(UOp(u.uop, u.dtype, u.vin[0:2], u.arg))) \
+            and u.vin[1].arg is BinaryOps.ADD and u.vin[1].vin[0].arg is not BinaryOps.MUL:
             if DEBUG >= 4: print(f"removing PHI node {u}")
             del self.saved_exprs[(u.uop, u.dtype, u.vin, u.arg)]
             # NOTE: assuming u.vin[2].vin[1] and u.vin[2].vin[0] have the same dtype
@@ -491,20 +493,14 @@ class Linearizer(Kernel):
     if x.op in ReduceOps and not do_reduce:
       assert offs is None, "not available if we aren't doing reduce"
       return acc
-    # MULACC fusion.
-    if x.op == ReduceOps.SUM:
-      if x.src[0].op == BinaryOps.MUL: x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg)
-      if (castop:=x.src[0]).op == UnaryOps.CAST and (mulop:=castop.src[0]).op == BinaryOps.MUL:
-        # MULACC with acc cast rewrite: MUL -> CAST -> SUM => CAST -> MULACC
-        x = LazyOp(TernaryOps.MULACC, tuple(LazyOp(UnaryOps.CAST, (s, ), castop.arg) for s in mulop.src), x.arg)
     values = [self.ast_parse(v, acc, offs, loaded_buffers, loop_ctx=loop_ctx, cache=cache) for v in x.src]
-    ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
+    ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}
     if x.op in ops:
       ret: List[UOp] = []
       input_acc = acc[:]
       for val, off in zip(zip(*values), cast(List[int], offs)):
-        acc[off] = self.uop(UOps.ALU, acc[off].dtype, vin=val+(acc[off],), arg=ops[x.op])
+        acc[off] = self.uop(UOps.ALU, acc[off].dtype, vin=val+(acc[off],), arg=ops[cast(ReduceOps, x.op)])
         ret.append(acc[off])
       for off in range(len(acc)):
         if input_acc[off] != acc[off]:
diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py
index a455f5eda2..b004ab8d9d 100644
--- a/tinygrad/codegen/uops.py
+++ b/tinygrad/codegen/uops.py
@@ -105,7 +105,7 @@ def uops_flops_mem(uops:List[UOp], vars:Dict[str, Variable]) -> Tuple[sint, sint
     if u.uop is UOps.ENDLOOP:
       mults = mult_stack.pop(-1)
     if u.uop is UOps.ALU:
-      flops += (2 if u.arg is TernaryOps.MULACC else 1) * mults
+      flops += mults
     if u.uop is UOps.LOAD:
       assert u.dtype is not None
       mem += u.dtype.itemsize * mults
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 4955ff513f..4b3d614133 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -14,7 +14,7 @@ from dataclasses import dataclass

 class UnaryOps(Enum): EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto() # noqa: E702
 class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPEQ = auto(); XOR = auto() # noqa: E702
-class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
+class TernaryOps(Enum): WHERE = auto() # noqa: E702
 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
 class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); SYNC = auto(); WAIT = auto() # noqa: E702
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 0e10c8887f..1652b50c61 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -30,7 +30,7 @@ class CStyleLanguage(NamedTuple):
     BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})",
     BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
     BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPEQ: lambda a,b,dtype: f"({a}=={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
-    TernaryOps.MULACC: lambda a,b,c,dtype: f"(({a}*{b})+{c})", TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"
+    TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"
   }

   # returns a str expression of the casted xs with the given type
@@ -180,9 +180,6 @@ class OpenCLLanguage(CStyleLanguage):
   float4 = "(float4)"
   code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"}
   uses_vload = True
-  # NOTE: mad is used so the loads aren't reordered into the math on 845
-  code_for_op = {**CStyleLanguage().code_for_op,
-                 TernaryOps.MULACC: lambda a,b,c,dtype: f"mad({a},{b},{c})" if dtypes.is_float(dtype) else f"(({a}*{b})+{c})"}
   type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong" }
   def render_cast(self, x, var_dtype, bitcast=False) -> str:
     return f"as_{self.type_map.get(var_dtype) or var_dtype.name}({x[0]})" if bitcast else super().render_cast(x, var_dtype)
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index 05fab8122c..db5f09f424 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -23,7 +23,6 @@ code_for_op: Final[Dict[Op, Callable]] = {
   BinaryOps.MAX: lambda builder, x, y, var_dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
   BinaryOps.MOD: lambda builder, x, y, var_dtype: builder.urem(x, y) if is_bool_or_unsigned(var_dtype) else builder.srem(x, y) if dtypes.is_int(var_dtype) else builder.frem(x, y), # noqa: E501
   BinaryOps.XOR: lambda builder, x, y, var_dtype: builder.xor(x, y),
-  TernaryOps.MULACC: lambda builder, x, y, z, var_dtype: builder.fadd(builder.fmul(x, y, flags=MFLAGS), z, flags=MFLAGS) if dtypes.is_float(var_dtype) else builder.add(builder.mul(x, y), z), # noqa: E501
   TernaryOps.WHERE: lambda builder, x, y, z, var_dtype: builder.select(x, y, z),
 }
diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py
index 429879eed4..b80bc7aad0 100644
--- a/tinygrad/runtime/ops_python.py
+++ b/tinygrad/runtime/ops_python.py
@@ -13,7 +13,6 @@ from tinygrad.codegen.kernel import LinearizerOptions
 def exec_alu(arg, dtype, p):
   # TODO: make this complete and correctly honor the dtypes
   # TODO: use this for constant folding
-  if arg == TernaryOps.MULACC: return p[0]*p[1]+p[2]
   if arg == TernaryOps.WHERE: return p[1] if p[0] else p[2]
   if arg == UnaryOps.LOG2: return math.log2(p[0]) if p[0] > 0 else -math.inf if p[0] == 0 else math.nan
   if arg == UnaryOps.EXP2: