@@ -5,7 +5,7 @@ from collections import defaultdict
 from enum import Enum, auto
 from dataclasses import dataclass
 
-from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, getenv
+from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, getenv, all_same
 from tinygrad.ops import LazyOp, UnaryOps, ConstBuffer, MemBuffer, BufferOps
 from tinygrad.ops import ReduceOps, BinaryOps, TernaryOps
 from tinygrad.shape.shapetracker import ShapeTracker
@@ -21,19 +21,13 @@ class UOps(Enum):
   LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto(); PHI = auto() # noqa: E702
   ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702
 
-@dataclass
+@dataclass(eq=False)
 class UOp:
   uop: UOps
   dtype: Optional[DType]
   vin: Tuple[UOp, ...]
   arg: Any
-  def __repr__(self): return f"{self.num:4d} {str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.num for x in self.vin]):32s} {self.arg}"
-  #def __repr__(self): return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str(self.vin):32s} {self.arg}"
-
-  # UOps are unique
-  num: int
-  def __hash__(self): return self.num
-  def __eq__(self, x): return self.num == x.num
+  def __repr__(self): return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"
 
 def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
   local_idxs = loop_local_idxs = [Variable(f"{prefix}{start_dim+i}", 0, s-1) for i,s in enumerate(local_dims[0:maxdim-1] + (prod(local_dims[maxdim-1:]),) if len(local_dims) > maxdim else local_dims)]
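Review note on the hunk above: dropping `num` means a UOp no longer carries a serial number, so the hand-rolled `__hash__`/`__eq__` must go too. `@dataclass(eq=False)` tells the dataclass not to generate a structural `__eq__` (which would also have made instances unhashable), leaving Python's default identity semantics. A minimal sketch of the resulting behavior (illustrative names, not from the PR):

    from dataclasses import dataclass
    from typing import Any, Optional, Tuple

    @dataclass(eq=False)
    class MiniUOp:
      uop: str
      dtype: Optional[str]
      vin: Tuple["MiniUOp", ...]
      arg: Any

    a = MiniUOp("CONST", "int32", (), 0)
    b = MiniUOp("CONST", "int32", (), 0)
    assert a == a and a != b  # identity equality, not structural
    assert len({a, b}) == 2   # still hashable; dedup happens in uop(), not here

Dedup of structurally equal ops is handled by the `saved_exprs` cache in `Linearizer.uop`, rewritten at the bottom of this diff.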
@@ -52,7 +46,7 @@ class Linearizer(Kernel):
     return self.uop(UOps.ALU, dtype, (a, render_b), op)
 
   # NOTE: the consts have to be cached for deduping of downstream uops to work
-  def const(self, b:Union[int,float], dtype=dtypes.int32) -> UOp: return self.uop(UOps.CONST, dtype, tuple(), b)
+  def const(self, b:Union[int,float], dtype=dtypes.int32, insert_before=None) -> UOp: return self.uop(UOps.CONST, dtype, tuple(), b, insert_before=insert_before)
 
   render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b),
                 MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
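Review note: `const` now threads an optional `insert_before` position through to `uop`, so the new passes below can materialize constants mid-list. The NOTE above it is load-bearing: `uop`'s cache key is `(uop, dtype, vin, arg)`, and `vin` holds UOp objects compared by identity, so downstream dedup only works if the same constant value always yields the same object. A toy illustration (stand-in objects, not tinygrad code):

    c0a, c0b = object(), object()  # two uncached "CONST 0" nodes
    # tuple keys compare elementwise, and distinct objects never compare equal:
    assert ("ALU", (c0a,), "ADD") != ("ALU", (c0b,), "ADD")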
@@ -215,7 +209,6 @@ class Linearizer(Kernel):
     # set global/local size
     self.global_size: Optional[List[int]] = None
     self.local_size: Optional[List[int]] = None
-    global_loop_ctx: Tuple[UOp, ...] = tuple()
     if self.dont_use_locals:
       self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
       self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)})
@@ -226,7 +219,7 @@ class Linearizer(Kernel):
       self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
       self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
     else:
-      global_loop_ctx = render_loop(loop_global_idxs+loop_local_idxs)
+      render_loop(loop_global_idxs+loop_local_idxs)
 
     # parse AST
     loaded_buffers = {}
@@ -302,7 +295,7 @@ class Linearizer(Kernel):
                  self.uop(UOps.CAST, dtypes._float8, tuple(op3)))
         ret = self.uop(UOps.WMMA, dtypes._float2 if wmma_sz[2] == 2 else dtypes._float8, ops, (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,))
         for z in range(cast(DType, ret.dtype).sz):
-          acc[i+z] = self.uop(UOps.PHI, dtypes.float, (op3[z], self.uop(UOps.GEP, dtypes.float, (ret,), z)) + global_loop_ctx + loop_ctx)
+          acc[i+z] = self.uop(UOps.PHI, dtypes.float, (op3[z], self.uop(UOps.GEP, dtypes.float, (ret,), z)) + loop_ctx)
         i += wmma_sz[2]
       else:
         if locals_to_store:
@@ -314,7 +307,7 @@ class Linearizer(Kernel):
         loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs})
 
       # run early AST (with reduce)
-      self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=global_loop_ctx + loop_ctx)
+      self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx)
 
       # end the reduce loop
       self.load_cache.clear()
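Review note: the last four hunks all remove the same thread of state. `global_loop_ctx` used to capture the UOps returned by `render_loop` for the global/local loops and was appended to every accumulator PHI's `vin` and passed through `ast_parse`. After this change the `render_loop` return value is discarded and PHIs reference only the reduce-loop context (`loop_ctx`); presumably this is what lets the new PHI-folding pass below reason about a PHI through its single LOOP operand. The next hunk removes the last use, in the late AST.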
@@ -365,11 +358,54 @@ class Linearizer(Kernel):
     loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer})
 
     # run late AST
-    val = self.ast_parse(self.ast, acc, None, loaded_buffers, loop_ctx=global_loop_ctx)
+    val = self.ast_parse(self.ast, acc, None, loaded_buffers)
 
     # store
     self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
 
+    # graph helper functions
+    def get_recursive_parents(x:List[UOp]) -> List[UOp]:
+      ret: Set[UOp] = set()
+      this_round: Set[UOp] = set(x)
+      while len(this_round):
+        ret = ret.union(this_round)
+        next_round: Set[UOp] = set()
+        for r in this_round: next_round = next_round.union(set(r.vin))
+        this_round = next_round
+      return list(ret)
+
+    def get_recursive_children(x:UOp) -> List[UOp]:
+      deps = set([x])
+      ssize = 0
+      while ssize != len(deps):
+        ssize = len(deps)
+        for u in self.uops:
+          if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])):
+            deps.add(u)
+      return sorted(list(deps), key=self.uops.index) # get the last one
+
+    def replace_op(old:UOp, new:UOp):
+      for u in self.uops:
+        u.vin = tuple(new if x is old else x for x in u.vin)
+      self.uops.remove(old)
+
+    # uops optimization
+    changed_something = True
+    while changed_something:
+      changed_something = False
+      for u in self.uops:
+        if u.uop == UOps.PHI and len(u.vin) == 3:
+          # if the parents of the PHI node don't have the LOOP in their parents, it can be folded
+          # TODO: ADD becomes a MUL, MAX can just become nothing
+          if all(x.uop != UOps.LOOP for x in get_recursive_parents(list(u.vin[0:2]))) and u.vin[1].arg == BinaryOps.ADD:
+            if DEBUG >= 4: print(f"removing PHI node {u}")
+            del self.saved_exprs[(u.uop, u.dtype, u.vin, u.arg)]
+            # NOTE: assuming u.vin[2].vin[1] and u.vin[2].vin[0] have the same dtype
+            loop_len = self.uop(UOps.ALU, u.vin[2].vin[1].dtype, (u.vin[2].vin[1], u.vin[2].vin[0]), BinaryOps.SUB, insert_before=self.uops.index(u))
+            if loop_len.dtype != u.dtype: loop_len = self.uop(UOps.CAST, u.dtype, (loop_len,), insert_before=self.uops.index(u))
+            replace_op(u, self.uop(UOps.ALU, u.dtype, (u.vin[1], loop_len,), BinaryOps.MUL, insert_before=self.uops.index(u)))
+            changed_something = True
+
     # (recursively) remove childless uops
     # NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that
     UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER, UOps.DEFINE_GLOBAL}
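Review note: the new "uops optimization" pass folds an ADD-based accumulator PHI whose two value inputs (`vin[0:2]`) have no LOOP among their recursive parents, i.e. the added value is loop-invariant. The loop length comes from the LOOP's own operands (apparently `vin = (start, end)`, hence the SUB), and the PHI is replaced by a MUL. A worked example of the arithmetic, assuming the accumulator starts at 0 and adds a loop-invariant c each iteration:

    start, end, c = 2, 10, 3.0
    acc = 0.0
    for _ in range(start, end): acc += c  # what the loop computed
    loop_len = end - start                # the inserted SUB (CAST if dtypes differ)
    assert acc == c * loop_len            # the PHI becomes a MUL

Per the TODO, a MAX accumulator could be folded even more cheaply: repeating max with the same loop-invariant operand changes nothing after the first iteration, so the PHI could collapse to the value itself.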
@@ -382,32 +418,21 @@ class Linearizer(Kernel):
       if len(nu) == len(self.uops): break
       if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
       self.uops = nu
     del nu
 
-    def get_recursive_deps(x:UOp) -> List[UOp]:
-      deps = set([x])
-      ssize = 0
-      while ssize != len(deps):
-        ssize = len(deps)
-        for u in self.uops:
-          if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])):
-            deps.add(u)
-      return sorted(list(deps), key=lambda x: x.num)
-
-    # add END of loops after the last thing that (recursively) depends on them
-    # and END any if statements
+    # add UOps.END
     for u in self.uops:
       if u.uop == UOps.LOOP:
-        last_phi = self.uops.index(get_recursive_deps(u)[-1])
-        at_end = self.uops[last_phi+1:]
-        self.uops = self.uops[:last_phi+1]
-        self.uop(UOps.END, None, (u,), cachable=False)
-        self.uops += at_end
+        # add END of loops after the last thing that (recursively) depends on them
+        self.uop(UOps.END, None, (u,), cachable=False, insert_before=self.uops.index(get_recursive_children(u)[-1])+1)
       elif u.uop == UOps.IF:
+        # END any if statements at the end of the uops
         self.uop(UOps.END, None, (u,), cachable=False)
 
     # maybe graph the uops
     if DEBUG >= 5:
-      for u in self.uops: print(u)
+      for u in self.uops:
+        print(f"{self.uops.index(u):4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} {str([self.uops.index(x) for x in u.vin]):32s} {u.arg}")
     if getenv("GRAPHUOPS"):
       from tinygrad.graph import graph_uops
       graph_uops(self.uops)
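Review note: END placement keeps the same rule (immediately after the last op that transitively depends on the LOOP, ignoring PHI-typed inputs) but the mechanics change: instead of slicing the uop list apart around `get_recursive_deps` and re-extending it, the END is emitted with `insert_before`. The helper survives as `get_recursive_children`, hoisted into the graph helpers above and sorted by list position (`key=self.uops.index`) now that `num` is gone. A toy demo that the two mechanisms agree:

    ops = ["LOOP", "ALU", "PHI", "STORE"]
    last = ops.index("PHI")  # stand-in for the last op depending on the loop
    spliced = ops[:last+1] + ["END"] + ops[last+1:]       # old: slice and re-extend
    inserted = list(ops); inserted.insert(last+1, "END")  # new: insert_before
    assert spliced == inserted == ["LOOP", "ALU", "PHI", "END", "STORE"]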
@@ -419,27 +444,32 @@ class Linearizer(Kernel):
     self.applied_opts_cache = self.applied_opts[:]
     return self
 
-  def uop(self, uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None, cachable=True) -> UOp:
+  def uop(self, uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None, cachable=True, insert_before=None, simplify=True) -> UOp:
     key = (uop, dtype, vin, arg)
-    if uop == UOps.PHI and len(vin) == 2 and vin[0] == vin[1]: return vin[0] # self phi is noop
-    if uop == UOps.GEP and vin[0].uop == UOps.CONST: return self.const(vin[0].arg, dtype)
-    if uop == UOps.ALU:
-      # rewrites. NOTE: the rewritten NEG op is still around...
-      if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG: return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable=cachable)
-      # constant folding
-      if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.const(-vin[0].arg, dtype)
-      # zero folding
-      for x in [0,1]:
-        if arg == BinaryOps.ADD and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
-        if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 1.0: return vin[1-x]
-        if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[x]
-      if arg == BinaryOps.SUB and vin[1].uop == UOps.CONST and vin[1].arg == 0.0: return vin[0]
-      if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]
+    if simplify:
+      if uop == UOps.PHI and len(vin) == 2: return vin[1] # a phi without loops is a noop
+      if uop == UOps.GEP and vin[0].uop == UOps.CONST: return self.const(vin[0].arg, dtype, insert_before)
+      if uop == UOps.CAST and all(x.uop == UOps.CONST for x in vin) and all_same([x.arg for x in vin]): return self.const(vin[0].arg, dtype, insert_before)
+      if uop == UOps.ALU:
+        # rewrites. NOTE: the rewritten NEG op is still around...
+        if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG: return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable=cachable, insert_before=insert_before)
+        # constant folding
+        if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.const(-vin[0].arg, dtype, insert_before)
+        # zero folding
+        for x in [0,1]:
+          if arg == BinaryOps.ADD and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
+          if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 1.0: return vin[1-x]
+          if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[x]
+        if arg == BinaryOps.SUB and vin[1].uop == UOps.CONST and vin[1].arg == 0.0: return vin[0]
+        if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]
     if cachable and key in self.saved_exprs: return self.saved_exprs[key]
-    self.uops.append(UOp(uop, dtype, vin, arg, len(self.uops)))
-    #if DEBUG >= 5: print(self.uops[-1])
-    if cachable: self.saved_exprs[key] = self.uops[-1]
-    return self.uops[-1]
+    ret = UOp(uop, dtype, vin, arg)
+    if insert_before is not None:
+      self.uops.insert(insert_before, ret)
+    else:
+      self.uops.append(ret)
+    if cachable: self.saved_exprs[key] = ret
+    return ret
 
   def ast_parse(self, x, acc, offs, loaded_buffers, do_reduce=False, loop_ctx=tuple()) -> List[UOp]:
     if x.__class__ is not LazyOp: return loaded_buffers[x] # for LOCAL_BUFFER
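Review note: the rewritten `uop` gains `insert_before` (splice the new op in at a position instead of appending, as the folding and END passes above do) and `simplify` (pass False to skip the rewrite block and force a real node). Two details worth flagging: the PHI rule changed meaning (it used to be a no-op only for `vin[0] == vin[1]`, returning `vin[0]`; now any two-input PHI, i.e. one with no LOOP operand, collapses to its updated value `vin[1]`), and a new CAST-of-identical-CONSTs fold uses the freshly imported `all_same`. A minimal sketch of the emission logic (illustrative stand-in, not tinygrad's API):

    class MiniLinearizer:
      def __init__(self): self.uops, self.saved_exprs = [], {}
      def uop(self, op, arg=None, cachable=True, insert_before=None):
        key = (op, arg)
        if cachable and key in self.saved_exprs: return self.saved_exprs[key]
        ret = (op, arg)
        if insert_before is not None: self.uops.insert(insert_before, ret)
        else: self.uops.append(ret)
        if cachable: self.saved_exprs[key] = ret
        return ret

    lin = MiniLinearizer()
    assert lin.uop("CONST", 0) is lin.uop("CONST", 0)  # deduped: emitted once
    lin.uop("END", cachable=False, insert_before=0)    # spliced, not appended
    assert lin.uops[0] == ("END", None)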