Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
remove MULACC (#3459)
* init
* removed mulacc
* is uoptimize the problem?
* lol hax make work temporarily fix l8er
* revert extra/ changes
* clean up
* flaky metal tests?
* add back mulacc for metal
* revert last commit
* try skipping linearizer_failure tests
* skip flammit tests... cuz tests all work locally
* try narrow down exact linearizer failure test
* try 2
* try 4
* generated code is the exact same wtf why CI fails
* code for 15 and 17 are exact same with or without mulacc, this should pass
* try only 1 failure
* try garbage collecting lol...
* try del variables lol
* try gcing after del lol...
* is diskcache the problem???
* try disabling opts cache idk
* try remove hack
* try disable github metal cache...
* try CACHELEVEL=0 :D idk anymore
* try increase newCommandQueueWithMaxCommandBufferCount_, im almost out of ideas...
* revert
* actually not a HACK
* oops
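In short: the squash removes `TernaryOps.MULACC` everywhere (ops enums, linearizer fusion, the uop interpreter, the C-style/OpenCL/LLVM renderers, flop counting), relying on multiply-accumulate being expressible as a MUL followed by an ADD. A minimal standalone sketch of that identity (hypothetical helper names, not tinygrad API):

```python
# The algebraic identity the removal relies on: the fused MULACC(a, b, c)
# is exactly ADD(MUL(a, b), c), so dropping the op loses no expressiveness,
# only the explicit single-uop fusion. Hypothetical helpers for illustration.
def mulacc(a: float, b: float, c: float) -> float:
    return a * b + c  # the old fused ternary op

def mul_then_add(a: float, b: float, c: float) -> float:
    return (a * b) + c  # the two binary ops that replace it

assert mulacc(3.0, 4.0, 5.0) == mul_then_add(3.0, 4.0, 5.0) == 17.0
```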
```diff
@@ -102,7 +102,7 @@ class LazyOp:
 class UnaryOps(Enum): EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto()
 class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); CMPLT = auto(); MAX = auto()
 class ReduceOps(Enum): SUM = auto(); MAX = auto()
-class TernaryOps(Enum): MULACC = auto(); WHERE = auto()
+class TernaryOps(Enum): WHERE = auto()
 class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto()
 Op = Union[UnaryOps, BinaryOps, ReduceOps, TernaryOps, LoadOps]
 
```
````diff
@@ -12,7 +12,6 @@ reduce_op (SUM, MAX) # A -> B (smaller s
 binary_op (ADD, SUB, MUL, DIV, CMPEQ, MAX) # A + A -> A (all the same size)
 load_op (EMPTY, CONST, FROM, CONTIGUOUS, CUSTOM) # -> A (initialize data on device)
 ternary_op (WHERE) # A, A, A -> A
-ternary_op [[optional]] (MULACC) # A * A -> B
 ```
 
 ## mlops
````
```diff
@@ -87,7 +87,6 @@ class TestFloatUOps(TestUOps):
   def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: a<b)
   # MOD isn't tested on floats
 
-  def test_mulacc(self): self._test_top_fxn(TernaryOps.MULACC, lambda a,b,c: (a*b)+c)
   def test_where(self):
     self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (PtrDType(dtypes.bool), PtrDType(dtypes.float), PtrDType(dtypes.float)))
 
```
```diff
@@ -423,7 +423,9 @@ class Linearizer(Kernel):
     if u.uop is UOps.PHI and len(u.vin) == 3:
       # if the parents of the PHI node don't have the LOOP in their parents, it can be folded
       # TODO: ADD becomes a MUL, MAX can just become nothing
-      if all(x.uop is not UOps.LOOP for x in get_recursive_parents(UOp(u.uop, u.dtype, u.vin[0:2], u.arg))) and u.vin[1].arg is BinaryOps.ADD:
+      # NOTE: ADD -> MUL does not fold, this maintains original MULACC code path
+      if all(x.uop is not UOps.LOOP for x in get_recursive_parents(UOp(u.uop, u.dtype, u.vin[0:2], u.arg))) \
+        and u.vin[1].arg is BinaryOps.ADD and u.vin[1].vin[0].arg is not BinaryOps.MUL:
         if DEBUG >= 4: print(f"removing PHI node {u}")
         del self.saved_exprs[(u.uop, u.dtype, u.vin, u.arg)]
         # NOTE: assuming u.vin[2].vin[1] and u.vin[2].vin[0] have the same dtype
```
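The new guard keeps a loop-accumulator PHI unfolded when its ADD feeds from a MUL, which (per the NOTE in the hunk) preserves the behavior of the original MULACC code path. A standalone sketch of just the predicate (toy enum, not tinygrad's UOp graph):

```python
# Toy version of the fold guard: a loop-accumulator PHI now folds only when
# its update is a plain ADD whose first operand is NOT a MUL, so the
# ADD(MUL(a, b), acc) pattern MULACC used to fuse keeps its loop-carried form.
from enum import Enum, auto
from typing import Optional

class BinaryOps(Enum): ADD = auto(); MUL = auto()  # only what the sketch needs

def can_fold_phi(update: BinaryOps, first_src: Optional[BinaryOps]) -> bool:
    # the real check also requires that no LOOP uop is among the parents
    return update is BinaryOps.ADD and first_src is not BinaryOps.MUL

assert can_fold_phi(BinaryOps.ADD, None)               # plain ADD still folds
assert not can_fold_phi(BinaryOps.ADD, BinaryOps.MUL)  # mul-accumulate stays
```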
```diff
@@ -491,20 +493,14 @@ class Linearizer(Kernel):
     if x.op in ReduceOps and not do_reduce:
       assert offs is None, "not available if we aren't doing reduce"
       return acc
-    # MULACC fusion.
-    if x.op == ReduceOps.SUM:
-      if x.src[0].op == BinaryOps.MUL: x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg)
-      if (castop:=x.src[0]).op == UnaryOps.CAST and (mulop:=castop.src[0]).op == BinaryOps.MUL:
-        # MULACC with acc cast rewrite: MUL -> CAST -> SUM => CAST -> MULACC
-        x = LazyOp(TernaryOps.MULACC, tuple(LazyOp(UnaryOps.CAST, (s, ), castop.arg) for s in mulop.src), x.arg)
 
     values = [self.ast_parse(v, acc, offs, loaded_buffers, loop_ctx=loop_ctx, cache=cache) for v in x.src]
-    ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
+    ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}
     if x.op in ops:
       ret: List[UOp] = []
       input_acc = acc[:]
       for val, off in zip(zip(*values), cast(List[int], offs)):
-        acc[off] = self.uop(UOps.ALU, acc[off].dtype, vin=val+(acc[off],), arg=ops[x.op])
+        acc[off] = self.uop(UOps.ALU, acc[off].dtype, vin=val+(acc[off],), arg=ops[cast(ReduceOps, x.op)])
         ret.append(acc[off])
       for off in range(len(acc)):
         if input_acc[off] != acc[off]:
```
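The deleted block is the fusion rule itself: a SUM reduce whose source was a MUL (optionally through a CAST) was rewritten into a fused MULACC per accumulation step. A standalone sketch (toy functions, not `LazyOp`) showing the values are unchanged either way:

```python
# Toy version of the pattern this hunk deletes: SUM(MUL(a, b)) used to become
# one fused mulacc(x, y, acc) per step; now it is an explicit MUL then ADD.
def sum_of_mul(a, b):
    acc = 0.0
    for x, y in zip(a, b):
        acc = (x * y) + acc  # post-removal: separate MUL and ADD uops
    return acc

# the fused form computed the identical value, so deleting the rewrite
# changes uop structure, not results
assert sum_of_mul([1.0, 2.0], [3.0, 4.0]) == 11.0
```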
```diff
@@ -105,7 +105,7 @@ def uops_flops_mem(uops:List[UOp], vars:Dict[str, Variable]) -> Tuple[sint, sint
     if u.uop is UOps.ENDLOOP:
       mults = mult_stack.pop(-1)
     if u.uop is UOps.ALU:
-      flops += (2 if u.arg is TernaryOps.MULACC else 1) * mults
+      flops += mults
     if u.uop is UOps.LOAD:
       assert u.dtype is not None
       mem += u.dtype.itemsize * mults
```
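Flop accounting stays consistent: a MULACC uop counted as 2 flops per execution, and the same math now appears as two ALU uops (MUL, ADD) counted once each. A quick check with illustrative numbers:

```python
# Illustrative only: one fused MULACC uop counted 2 flops per iteration;
# the split form is two ALU uops at 1 flop each, so totals match.
mults = 1024                          # iterations the ALU uop runs for
flops_fused = 2 * mults               # old: single MULACC counted twice
flops_split = 1 * mults + 1 * mults   # new: separate MUL and ADD uops
assert flops_fused == flops_split
```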
```diff
@@ -14,7 +14,7 @@ from dataclasses import dataclass
 class UnaryOps(Enum): EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto() # noqa: E702
 class BinaryOps(Enum):
   ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPEQ = auto(); XOR = auto() # noqa: E702
-class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
+class TernaryOps(Enum): WHERE = auto() # noqa: E702
 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
 class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); SYNC = auto(); WAIT = auto() # noqa: E702
```
```diff
@@ -30,7 +30,7 @@ class CStyleLanguage(NamedTuple):
     BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})",
     BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
     BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPEQ: lambda a,b,dtype: f"({a}=={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
-    TernaryOps.MULACC: lambda a,b,c,dtype: f"(({a}*{b})+{c})", TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"
+    TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"
   }
 
   # returns a str expression of the casted xs with the given type
```
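For the C-style backends the emitted source is effectively unchanged: the deleted MULACC entry rendered `((a*b)+c)` in one step, and composing the remaining MUL and ADD entries yields the same expression. A standalone sketch (bare lambdas mirroring the table above, not the `CStyleLanguage` class):

```python
# Composing the two binary renderers reproduces what the fused MULACC entry
# used to emit in a single step.
render_mul = lambda a, b: f"({a}*{b})"
render_add = lambda a, b: f"({a}+{b})"

assert render_add(render_mul("val0", "val1"), "acc0") == "((val0*val1)+acc0)"
```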
```diff
@@ -180,9 +180,6 @@ class OpenCLLanguage(CStyleLanguage):
   float4 = "(float4)"
   code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"}
   uses_vload = True
-  # NOTE: mad is used so the loads aren't reordered into the math on 845
-  code_for_op = {**CStyleLanguage().code_for_op,
-    TernaryOps.MULACC: lambda a,b,c,dtype: f"mad({a},{b},{c})" if dtypes.is_float(dtype) else f"(({a}*{b})+{c})"}
   type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong" }
   def render_cast(self, x, var_dtype, bitcast=False) -> str:
     return f"as_{self.type_map.get(var_dtype) or var_dtype.name}({x[0]})" if bitcast else super().render_cast(x, var_dtype)
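One behavioral caveat here: the deleted OpenCL override emitted `mad()` for float MULACCs, which per the removed NOTE kept loads from being reordered into the math on the 845; that workaround goes away with the op. A standalone lambda mirroring the deleted dispatch (toy `is_float` flag instead of `dtypes`):

```python
# Toy mirror of the deleted OpenCL code_for_op entry: mad() for floats,
# plain multiply-add for everything else.
def render_mulacc(a, b, c, is_float):
    return f"mad({a},{b},{c})" if is_float else f"(({a}*{b})+{c})"

assert render_mulacc("x", "y", "acc", True) == "mad(x,y,acc)"
assert render_mulacc("x", "y", "acc", False) == "((x*y)+acc)"
```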
```diff
@@ -23,7 +23,6 @@ code_for_op: Final[Dict[Op, Callable]] = {
   BinaryOps.MAX: lambda builder, x, y, var_dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
   BinaryOps.MOD: lambda builder, x, y, var_dtype: builder.urem(x, y) if is_bool_or_unsigned(var_dtype) else builder.srem(x, y) if dtypes.is_int(var_dtype) else builder.frem(x, y), # noqa: E501
   BinaryOps.XOR: lambda builder, x, y, var_dtype: builder.xor(x, y),
-  TernaryOps.MULACC: lambda builder, x, y, z, var_dtype: builder.fadd(builder.fmul(x, y, flags=MFLAGS), z, flags=MFLAGS) if dtypes.is_float(var_dtype) else builder.add(builder.mul(x, y), z), # noqa: E501
   TernaryOps.WHERE: lambda builder, x, y, z, var_dtype: builder.select(x, y, z),
 }
 
```
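Note the LLVM backend never had a truly fused instruction here: the deleted entry already built `fadd(fmul(x, y), z)` itself, so removal just moves that composition up into the linearizer. A sketch with toy stand-ins (a real llvmlite builder needs module/function boilerplate):

```python
# Toy stand-ins for builder.fmul / builder.fadd, to show the deleted MULACC
# lambda was itself just a composition of the two remaining binary entries.
def fmul(x, y): return f"fmul({x}, {y})"
def fadd(x, y): return f"fadd({x}, {y})"

assert fadd(fmul("x", "y"), "z") == "fadd(fmul(x, y), z)"
```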
```diff
@@ -13,7 +13,6 @@ from tinygrad.codegen.kernel import LinearizerOptions
 def exec_alu(arg, dtype, p):
   # TODO: make this complete and correctly honor the dtypes
   # TODO: use this for constant folding
-  if arg == TernaryOps.MULACC: return p[0]*p[1]+p[2]
   if arg == TernaryOps.WHERE: return p[1] if p[0] else p[2]
   if arg == UnaryOps.LOG2: return math.log2(p[0]) if p[0] > 0 else -math.inf if p[0] == 0 else math.nan
   if arg == UnaryOps.EXP2:
```
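The interpreter's MULACC case is likewise recoverable by chaining two calls, assuming the `BinaryOps.MUL` and `BinaryOps.ADD` cases `exec_alu` handles elsewhere in the function (not shown in this hunk). A standalone toy version of that chaining:

```python
# Toy exec_alu with string op tags, just to show the chaining:
# exec_alu(ADD, dtype, [exec_alu(MUL, dtype, [a, b]), c]) == a*b + c
def toy_exec_alu(arg, p):
    if arg == "MUL": return p[0] * p[1]
    if arg == "ADD": return p[0] + p[1]

assert toy_exec_alu("ADD", [toy_exec_alu("MUL", [3.0, 4.0]), 5.0]) == 17.0
```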