Refactor to UOpGraph class (#3566)

* Refactor to UOpGraph class

* fix test
George Hotz
2024-03-01 15:14:48 -08:00
committed by GitHub
parent b7e555f6c0
commit 6b29c70b3d
5 changed files with 253 additions and 235 deletions

View File

@@ -13,6 +13,7 @@ from tinygrad.features.jit import CacheCollector
from tinygrad.realize import create_schedule, run_schedule
from tinygrad.helpers import prod, Context
from tinygrad.dtype import DType, dtypes
from tinygrad.codegen.uops import UOpGraph
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "linearizer is only for compiled backends")
class TestLinearizer(unittest.TestCase):
@@ -54,8 +55,8 @@ class TestLinearizer(unittest.TestCase):
lin = Linearizer(ast)
lin.linearize()
a_bufs = [u.uop for u in lin.uops[-2].vin[0].vin]
b_bufs = [u.uop for u in lin.uops[-2].vin[1].vin]
a_bufs = [u.uop for u in lin.uops.uops[-2].vin[0].vin]
b_bufs = [u.uop for u in lin.uops.uops[-2].vin[1].vin]
assert a_bufs == [UOps.LOAD, UOps.CONST]
assert b_bufs == [UOps.CONST, UOps.CONST]
@@ -199,8 +200,8 @@ class TestLinearizer(unittest.TestCase):
MemBuffer(0, dtypes.float, ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),))))
lin = Linearizer(ast=ast) # this is a dummy ast
lin.uops = []
return lin.uop(uop, dtype, vin, arg, cachable=False)
lin.uops = UOpGraph()
return lin.uops.add(uop, dtype, vin, arg, cachable=False)
c0 = UOp(UOps.CONST, dtypes.float, vin=(), arg=0.0)
assert helper_test_simplify(UOps.ALU, dtypes.float, vin=(UOp(UOps.CONST, dtypes.bool, vin=(), arg=True), c0, c0), arg=TernaryOps.WHERE) == c0
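
The test change above is the whole migration in miniature: lin.uops is now a UOpGraph rather than a bare list (so positional indexing goes through lin.uops.uops), and new uops are created with UOpGraph.add instead of Linearizer.uop. A minimal sketch of that pattern, assuming this commit's modules import as shown in the diff; the concrete UOp values are only illustrative:

from tinygrad.codegen.uops import UOps, UOp, UOpGraph
from tinygrad.dtype import dtypes
from tinygrad.ops import TernaryOps

g = UOpGraph()                                        # replaces the old bare List[UOp]
c0 = UOp(UOps.CONST, dtypes.float, vin=(), arg=0.0)
cond = UOp(UOps.CONST, dtypes.bool, vin=(), arg=True)
# add() keeps the old uop() simplifications: a WHERE whose branches are identical
# folds to that branch, so nothing is appended to the graph
out = g.add(UOps.ALU, dtypes.float, vin=(cond, c0, c0), arg=TernaryOps.WHERE, cachable=False)
assert out == c0 and len(g.uops) == 0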

View File

@@ -1,17 +1,17 @@
from __future__ import annotations
from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, Final, Set, Iterator, Sequence
from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, Final, Iterator, Sequence
import itertools, math, functools
from collections import defaultdict
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
from tinygrad.helpers import colored, DEBUG, prod, getenv, all_same, to_function_name
from tinygrad.helpers import colored, DEBUG, prod, getenv, to_function_name
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
from tinygrad.codegen.kernel import LocalBuffer, Kernel
from tinygrad.features.image import to_image_idx
from tinygrad.codegen.uops import UOps, UOp, get_recursive_children, remove_childless_uops, fix_loop_scope, uops_type_verify
from tinygrad.codegen.uops import UOps, UOp, UOpGraph
def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
local_idxs = loop_local_idxs = [Variable(f"{prefix}{start_dim+i}", 0, s-1) for i,s in enumerate(local_dims[0:maxdim-1] + (prod(local_dims[maxdim-1:]),) if len(local_dims) > maxdim else local_dims)] # noqa: E501
@@ -41,13 +41,13 @@ def expand_node(node:Node, idxs:Optional[Tuple[Union[Variable, NumNode], ...]]=N
class Linearizer(Kernel):
def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op, dtype=dtypes.int32):
render_b:UOp = cast(UOp, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
return self.uop(UOps.ALU, dtype, (a, render_b), op)
return self.uops.add(UOps.ALU, dtype, (a, render_b), op)
# NOTE: the consts have to be cached for deduping of downstream uops to work
def const(self, b:Union[int,float], dtype=dtypes.int32, insert_before=None) -> UOp:
return self.uop(UOps.CONST, dtype, tuple(), b, insert_before=insert_before)
return self.uops.add(UOps.CONST, dtype, tuple(), b, insert_before=insert_before)
def cast(self, val: UOp, dtype) -> UOp: return self.uop(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val
def cast(self, val: UOp, dtype) -> UOp: return self.uops.add(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val
def get_reduce_acc(self, reduceop:LazyOp):
info = get_lazyop_info(reduceop)
@@ -98,35 +98,36 @@ class Linearizer(Kernel):
key = f"{acc}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501
if key not in self.load_cache:
if acc is not None:
self.load_cache[key] = self.uop(UOps.DEFINE_ACC, localtype, (), this_const, cachable=False)
self.load_cache[key] = self.uops.add(UOps.DEFINE_ACC, localtype, (), this_const, cachable=False)
elif this_const is not None:
self.load_cache[key] = self.const(this_const, localtype)
if valid.min == 0 and valid.max == 1:
valid_rendered = valid.render(self.render_ops, self)
self.load_cache[key] = self.uop(UOps.ALU, localtype, (valid_rendered, self.load_cache[key], self.const(invalid_value, localtype)), TernaryOps.WHERE) # noqa: E501
self.load_cache[key] = self.uops.add(UOps.ALU, localtype, (valid_rendered, self.load_cache[key], self.const(invalid_value, localtype)), TernaryOps.WHERE) # noqa: E501
elif isinstance(buf.dtype, ImageDType):
buf_uop = self.buf_uops[i]
assert buf_uop is not None, f"buffer {i} wasn't UOped"
image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
rendered_idx = self.uop(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx))
rendered_idx = self.uops.add(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx))
valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, buf.dtype.base.vec(4))) if valid.min == 0 else tuple()
self.load_cache[key] = self.uop(UOps.LOAD, buf.dtype.base.vec(4), (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
self.load_cache[key] = self.uops.add(UOps.LOAD, buf.dtype.base.vec(4),
(buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
if localtype == localtype.scalar():
idx_small = idx%4
res = idx_small.render(self.render_ops, self)
out = self.uop(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max)
out = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max)
for ix in range(idx_small.max, idx_small.min, -1):
rvv = self.uop(UOps.GEP, localtype, (self.load_cache[key],), ix-1)
sel = self.uop(UOps.ALU, dtypes.bool, (res, self.const(ix)), BinaryOps.CMPLT)
out = self.uop(UOps.ALU, localtype, (sel, rvv, out), TernaryOps.WHERE)
rvv = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), ix-1)
sel = self.uops.add(UOps.ALU, dtypes.bool, (res, self.const(ix)), BinaryOps.CMPLT)
out = self.uops.add(UOps.ALU, localtype, (sel, rvv, out), TernaryOps.WHERE)
self.load_cache[key] = out
else:
buf_uop = self.buf_uops[i]
assert buf_uop is not None, f"buffer {i} wasn't UOped"
rendered_idx = idx.render(self.render_ops, self)
valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, localtype)) if valid.min == 0 else tuple()
self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
ret.append(self.uop(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
self.load_cache[key] = self.uops.add(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
ret.append(self.uops.add(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
return ret
def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> List[UOp]:
@@ -149,7 +150,7 @@ class Linearizer(Kernel):
amt = len(grouped)
idx, valid = self.sts[i].expr_idxs(k)
assert idx == ((idx//amt)*amt), "float4 stores are always aligned"
store_offset_new[k] = self.uop(UOps.CAST, buf.dtype.vec(amt), tuple(grouped))
store_offset_new[k] = self.uops.add(UOps.CAST, buf.dtype.vec(amt), tuple(grouped))
store_offset = store_offset_new
stores = []
@@ -157,11 +158,11 @@ class Linearizer(Kernel):
idx, valid = self.sts[i].expr_idxs(_idx)
if isinstance(buf.dtype, ImageDType):
image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
rendered_idx = self.uop(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx))
rendered_idx = self.uops.add(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx))
else:
rendered_idx = idx.render(self.render_ops, self)
if valid.min == 1: stores.append(self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var)))
else: stores.append(self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self))))
if valid.min == 1: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var)))
else: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self))))
return stores
kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
@@ -185,30 +186,31 @@ class Linearizer(Kernel):
if self.opts.global_max and self.opts.local_max: self.limit_dims_to_max(self.opts.global_max, self.opts.local_max)
# uops
self.uops: List[UOp] = []
self.uops:UOpGraph = UOpGraph()
self.buf_uops: List[Optional[UOp]] = [None]*len(self.bufs)
self.loop_uops: Dict[str, UOp] = {}
# add global buffers
for i,buf in enumerate(self.bufs):
if isinstance(buf, MemBuffer):
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL,
buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
(buf.idx, f"data{buf.idx}"))
self.buf_uops[i] = self.uops.add(UOps.DEFINE_GLOBAL,
buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
(buf.idx, f"data{buf.idx}"))
# add var vals
for i,var in enumerate(self.ast.vars()):
assert var.expr is not None
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_VAR, dtypes.int32, (), var)
self.loop_uops[var.expr] = self.uops.add(UOps.DEFINE_VAR, dtypes.int32, (), var)
# define local buffers
for lb in self.local_alias.values():
self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size))
self.buf_uops[self.bufs.index(lb)] = self.uops.add(UOps.DEFINE_LOCAL,
PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size))
# add a local buffer for multistage reduce. # TODO: use local alias
if self.group_for_reduces:
# TODO: the strides of this can be controlled
self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) # noqa: E501
temp_dtype = self.get_base_dtype(get_lazyop_info(self.reduceop).dtype)
self.bufs.append(LocalBuffer("temp", self.sts[-1].size, temp_dtype))
self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), ("temp", self.sts[-1].size)))
self.buf_uops.append(self.uops.add(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), ("temp", self.sts[-1].size)))
# kernel name (before late upcast)
self.name = ("r_" if self.reduceop else "E_") + colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
@@ -226,7 +228,7 @@ class Linearizer(Kernel):
# global and local loops
def render_loop(xx:List[Variable]) -> Tuple[UOp, ...]:
new_loops = {x.expr:self.uop(UOps.LOOP, dtypes.int32, (
new_loops = {x.expr:self.uops.add(UOps.LOOP, dtypes.int32, (
self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None} # noqa: E501
self.loop_uops.update(new_loops)
@@ -237,11 +239,11 @@ class Linearizer(Kernel):
self.local_size: Optional[List[int]] = None
if self.dont_use_locals:
self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
elif self.opts.has_local:
self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs][::-1]
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)}) # noqa: E501
self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)}) # noqa: E501
else:
render_loop(loop_global_idxs+loop_local_idxs)
@@ -311,12 +313,12 @@ class Linearizer(Kernel):
upcasts = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]]
for iter in [x[::-1] for x in [x for x in itertools.product(*[x for x in [range(sz) for _,sz in upcasts[0]][::-1]])]]:
offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(iter, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
ops = (self.uop(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
self.uop(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
self.uop(UOps.CAST, tc.dtype_out.vec(wmma_sz[2]), tuple(op3:=acc[offs[2]:offs[2]+wmma_sz[2]])))
ret = self.uop(UOps.WMMA, tc.dtype_out.vec(wmma_sz[2]), ops, tc.wmma_func)
ops = (self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
self.uops.add(UOps.CAST, tc.dtype_out.vec(wmma_sz[2]), tuple(op3:=acc[offs[2]:offs[2]+wmma_sz[2]])))
ret = self.uops.add(UOps.WMMA, tc.dtype_out.vec(wmma_sz[2]), ops, tc.wmma_func)
for z in range(wmma_sz[2]):
acc[offs[2]+z] = self.uop(UOps.PHI, tc.dtype_out, (op3[z], self.uop(UOps.GEP, tc.dtype_out, (ret,), z)) + loop_ctx)
acc[offs[2]+z] = self.uops.add(UOps.PHI, tc.dtype_out, (op3[z], self.uops.add(UOps.GEP, tc.dtype_out, (ret,), z)) + loop_ctx)
else:
assert not locals_to_store, "storing locals isn't supported here"
@@ -333,12 +335,12 @@ class Linearizer(Kernel):
if self.group_for_reduces:
fake_global_idxs = [x*0 for x in global_idxs]
stores = self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
barrier = self.uop(UOps.BARRIER, None, tuple(stores), cachable=False)
barrier = self.uops.add(UOps.BARRIER, None, tuple(stores), cachable=False)
if self.opts.has_local:
fake_idxs = [NumNode(0)]*len(self.sts[-1].shape)
fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
if_cond: UOp = (self.sts[-1].expr_idxs(fake_idxs)[0]<1).render(self.render_ops, self)
barrier = self.uop(UOps.IF, None, (if_cond, barrier), cachable=False)
barrier = self.uops.add(UOps.IF, None, (if_cond, barrier), cachable=False)
# create new late reduce local loops and replace local_idxs that have been used
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+self.group_for_reduces)] # noqa: E501
@@ -385,15 +387,11 @@ class Linearizer(Kernel):
self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
# optimize the uops
self.uoptimize()
self.uops.uoptimize()
# maybe graph the uops
if DEBUG >= 5:
for u in self.uops:
print(f"{self.uops.index(u):4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} {str([self.uops.index(x) for x in u.vin]):32s} {u.arg}") # noqa: E501
if getenv("GRAPHUOPS"):
from tinygrad.features.graph import graph_uops
graph_uops(self.uops)
if DEBUG >= 5: self.uops.print()
if getenv("GRAPHUOPS"): self.uops.graph()
# restore backups
self.sts, self.group_for_reduces, self.upcasted = sts_backup, gfr_backup, upc_backup
@@ -402,96 +400,12 @@ class Linearizer(Kernel):
self.applied_opts_cache = self.applied_opts[:]
return self
def uoptimize(self):
# get PHI node loop scope, link anything using a DEFINE_ACC to the loop as a "parent"
acc_scope: DefaultDict[UOp, List[UOp]] = defaultdict(list)
for u in self.uops:
if u.uop == UOps.PHI:
acc_scope[u.vin[0]] += u.vin[2:]
# graph helper functions
@functools.lru_cache(None)
def get_recursive_parents(x:UOp, with_phi=False) -> Set[UOp]:
return set.union(set(x.vin), *[get_recursive_parents(p, with_phi) for p in x.vin], set(acc_scope[x]) if with_phi else set())
# fix loop scope, push uops upward out of loop if it does not depend on the loop
self.uops = fix_loop_scope(get_recursive_parents, self.uops)
# uops optimization
changed_something = True
while changed_something:
changed_something = False
for u in self.uops:
if u.uop is UOps.PHI and len(u.vin) == 3:
# if the parents of the PHI node don't have the LOOP in their parents, it can be folded
# TODO: ADD becomes a MUL, MAX can just become nothing
# NOTE: ADD -> MUL does not fold, this maintains original MULACC code path
if all(x.uop is not UOps.LOOP for x in get_recursive_parents(UOp(u.uop, u.dtype, u.vin[0:2], u.arg))) \
and u.vin[1].arg is BinaryOps.ADD and u.vin[1].vin[0].arg is not BinaryOps.MUL:
if DEBUG >= 4: print(f"removing PHI node {u}")
del self.saved_exprs[(u.uop, u.dtype, u.vin, u.arg)]
# NOTE: assuming u.vin[2].vin[1] and u.vin[2].vin[0] have the same dtype
loop_len = self.uop(UOps.ALU, u.vin[2].vin[1].dtype, (u.vin[2].vin[1], u.vin[2].vin[0]), BinaryOps.SUB, insert_before=self.uops.index(u))
if loop_len.dtype != u.dtype: loop_len = self.uop(UOps.CAST, u.dtype, (loop_len,), insert_before=self.uops.index(u))
new = self.uop(UOps.ALU, u.dtype, (u.vin[1], loop_len,), BinaryOps.MUL, insert_before=self.uops.index(u))
# replace u with new
for v in self.uops: v.vin = tuple(new if x is u else x for x in v.vin)
self.uops.remove(u)
changed_something = True
# (recursively) remove childless uops
self.uops = remove_childless_uops(self.uops)
# add UOps.END*
for u in self.uops:
if u.uop is UOps.LOOP:
# add END of loops after the last thing that (recursively) depends on them
self.uop(UOps.ENDLOOP, None, (u,), cachable=False, insert_before=self.uops.index(sorted(list(get_recursive_children(self.uops, u)), key=self.uops.index)[-1])+1) # noqa: E501
elif u.uop is UOps.IF:
# END any if statements at the end of the uops
self.uop(UOps.ENDIF, None, (u,), cachable=False)
# verify the uop types
uops_type_verify(self.uops)
def uop(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None, cachable=True, insert_before=None, simplify=True) -> UOp: # noqa: E501
if simplify:
if uop is UOps.PHI and len(vin) == 2: return vin[1] # a phi without loops is a noop
if uop is UOps.GEP and vin[0].uop is UOps.CONST: return self.const(vin[0].arg, dtype, insert_before)
if uop is UOps.CAST and all(x.uop is UOps.CONST for x in vin) and all_same([x.arg for x in vin]):
return self.const(vin[0].arg, dtype, insert_before)
if uop is UOps.ALU:
# rewrites. NOTE: the rewritten NEG op is still around...
if arg is BinaryOps.ADD and vin[1].uop is UOps.ALU and vin[1].arg is UnaryOps.NEG:
return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable, insert_before)
# constant folding
if arg is UnaryOps.NEG and vin[0].uop is UOps.CONST: return self.const(-vin[0].arg, dtype, insert_before)
if arg is TernaryOps.WHERE and vin[1] == vin[2]: return vin[1] # a conditional with the same results either way is a noop
if arg is BinaryOps.MUL and vin[0].uop is UOps.CONST and vin[1].uop is UOps.CONST and dtype is not None and dtypes.is_float(dtype):
return self.const(vin[0].arg * vin[1].arg, dtype, insert_before)
# zero folding
for x in [0,1]:
if arg is BinaryOps.ADD and vin[x].uop is UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
if arg is BinaryOps.MUL and vin[x].uop is UOps.CONST and vin[x].arg == 1.0: return vin[1-x]
if arg is BinaryOps.MUL and vin[x].uop is UOps.CONST and vin[x].arg == 0.0: return vin[x]
if arg is BinaryOps.SUB and vin[1].uop is UOps.CONST and vin[1].arg == 0.0: return vin[0]
if arg is BinaryOps.DIV and vin[1].uop is UOps.CONST and vin[1].arg == 1.0: return vin[0]
key = (uop, dtype, vin, arg)
if insert_before is None: insert_before = len(self.uops)
# check if the cached expr is valid with the given insert place.
if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and self.uops.index(expr) <= insert_before: return expr
ret = UOp(uop, dtype, vin, arg)
self.uops.insert(insert_before, ret)
if cachable: self.saved_exprs[key] = ret
return ret
def ast_parse(self, x:LazyOp, acc: List[UOp], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], do_reduce=False, loop_ctx=tuple(), cache=None) -> List[UOp]: # noqa: E501
if cache is None: cache = {}
if x in cache: return cache[x]
if x.op in BufferOps: return loaded_buffers[x.arg]
if x.op == UnaryOps.CAST:
return [self.uop(UOps.CAST, self.get_base_dtype(x.arg[0]), (u,), x.arg) for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers)]
return [self.uops.add(UOps.CAST, self.get_base_dtype(x.arg[0]), (u,), x.arg) for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers)]
if x.op in ReduceOps and not do_reduce:
assert offs is None, "not available if we aren't doing reduce"
return acc
@@ -502,12 +416,12 @@ class Linearizer(Kernel):
ret: List[UOp] = []
input_acc = acc[:]
for val, off in zip(zip(*values), cast(List[int], offs)):
acc[off] = self.uop(UOps.ALU, acc[off].dtype, vin=val+(acc[off],), arg=ops[cast(ReduceOps, x.op)])
acc[off] = self.uops.add(UOps.ALU, acc[off].dtype, vin=val+(acc[off],), arg=ops[cast(ReduceOps, x.op)])
ret.append(acc[off])
for off in range(len(acc)):
if input_acc[off] != acc[off]:
acc[off] = self.uop(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]) + tuple(loop_ctx))
acc[off] = self.uops.add(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]) + tuple(loop_ctx))
else:
ret = [self.uop(UOps.ALU, dtypes.bool if x.op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else val[-1].dtype, val, x.op) for val in zip(*values)]
ret = [self.uops.add(UOps.ALU, dtypes.bool if x.op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else val[-1].dtype, val, x.op) for val in zip(*values)]
cache[x] = ret
return ret
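
Inside linearizer.py the change is mechanical: every self.uop(...) call becomes self.uops.add(...), and the uoptimize pass plus the DEBUG/GRAPHUOPS output move onto the graph object. A hedged sketch of the post-refactor flow from a caller's point of view, where ast stands in for a real scheduled LazyOp:

from tinygrad.codegen.linearizer import Linearizer

lin = Linearizer(ast)            # ast: placeholder for a scheduled LazyOp, not built here
lin.linearize()                  # builds uops via self.uops.add(...) then runs self.uops.uoptimize()
lin.uops.print()                 # replaces the inline DEBUG >= 5 printing loop above
# lin.uops.graph()               # the GRAPHUOPS path, also moved onto the graph
variables = lin.uops.vars()      # DEFINE_VAR args; device.py and search.py below switch to this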

View File

@@ -1,9 +1,11 @@
from __future__ import annotations
from typing import List, Set, Optional, Tuple, Any
from tinygrad.helpers import DEBUG, flatten
import functools
from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict
from collections import defaultdict
from tinygrad.helpers import DEBUG, flatten, all_same
from tinygrad.dtype import dtypes, DType
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.shape.symbolic import sint
from tinygrad.shape.symbolic import sint, Variable
from enum import Enum, auto
from dataclasses import dataclass
@@ -23,96 +25,201 @@ class UOp:
def __repr__(self):
return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"
def get_recursive_children(uops:List[UOp], x:UOp) -> Set[UOp]:
deps = set([x])
ssize = 0
while ssize != len(deps):
ssize = len(deps)
for u in uops:
if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])):
deps.add(u)
return deps
UOPS_W_SIDE_EFFECTS = {UOps.STORE}
def remove_childless_uops(uops:List[UOp]) -> List[UOp]:
while 1:
has_child: Set[UOp] = set()
for ru in uops:
for vu in ru.vin:
has_child.add(vu)
nu: List[UOp] = [x for x in uops if x in has_child or x.uop in UOPS_W_SIDE_EFFECTS]
if len(nu) == len(uops): break
if DEBUG >= 4: print(f"reduced UOp count from {len(uops)} to {len(nu)}")
uops = nu
del nu
return uops
def fix_loop_scope(get_recursive_parents, uops:List[UOp]) -> List[UOp]:
loop_stack: List[List[UOp]] = [[]]
# push uops upward out of loop if it does not depend on the loop
for u in uops:
if not loop_stack[-1]: loop_stack[-1].append(u)
elif u.uop == UOps.LOOP: loop_stack.append([u])
elif u.uop not in [UOps.CONST, UOps.ALU, UOps.CAST, UOps.LOAD]: loop_stack[-1].append(u)
else:
parents = get_recursive_parents(u, with_phi=True)
# don't push any local buffer because there might be STORE and BARRIER (not considered as parents) between DEFINE_LOCAL and here
if any(u.uop == UOps.DEFINE_LOCAL for u in parents): loop_stack[-1].append(u)
else:
for i in reversed(range(len(loop_stack))):
# check backwards and put the uop in the first encounter with some dependency
if any(x in parents for x in loop_stack[i]) or i == 0:
loop_stack[i].append(u)
break
return flatten(loop_stack)
# optional
def uops_type_verify(uops:List[UOp]):
for u in uops:
uop, arg, vin, dtype = u.uop, u.arg, u.vin, u.dtype
if uop == UOps.ALU:
if arg in UnaryOps:
assert dtype == vin[0].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=}"
elif arg in (BinaryOps.CMPLT, BinaryOps.CMPEQ):
assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
assert vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
elif arg in BinaryOps:
assert dtype == vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
elif arg == TernaryOps.WHERE:
assert vin[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {vin[0].dtype=} != {dtypes.bool}"
assert dtype == vin[1].dtype == vin[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {vin[1].dtype=} != {vin[2].dtype=}"
def uops_alu_resolve(u:UOp) -> sint:
def uop_alu_resolve(u:UOp) -> sint:
if u.uop == UOps.CONST: return u.arg
elif u.uop == UOps.DEFINE_VAR: return u.arg
elif u.uop == UOps.ALU and u.arg == BinaryOps.MUL:
return uops_alu_resolve(u.vin[0]) * uops_alu_resolve(u.vin[1])
return uop_alu_resolve(u.vin[0]) * uop_alu_resolve(u.vin[1])
elif u.uop == UOps.ALU and u.arg == BinaryOps.ADD:
return uops_alu_resolve(u.vin[0]) + uops_alu_resolve(u.vin[1])
return uop_alu_resolve(u.vin[0]) + uop_alu_resolve(u.vin[1])
else:
raise RuntimeError(f"ALU resolve fail @ {u.uop}")
def uops_flops_mem(uops:List[UOp]) -> Tuple[sint, sint]:
flops: sint = 0
mem: sint = 0
mults: sint = 1
mult_stack = []
for u in uops:
if u.uop is UOps.LOOP:
mult_stack.append(mults)
mults *= uops_alu_resolve(u.vin[1])
if u.uop is UOps.ENDLOOP:
mults = mult_stack.pop(-1)
if u.uop is UOps.ALU:
flops += mults
if u.uop is UOps.LOAD:
assert u.dtype is not None
mem += u.dtype.itemsize * mults
if u.uop is UOps.STORE:
assert u.vin[2].dtype is not None
mem += u.vin[2].dtype.itemsize * mults
if u.uop is UOps.WMMA:
if u.arg.startswith("__metal_wmma"): flops += 2*(8*8*8)//32 * mults
elif u.arg == "__hip_wmma_f16_f16" or u.arg == "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32": flops += 2*(16*16*16)//32 * mults
else: raise Exception("not implemented")
return flops, mem
class UOpGraph:
def __init__(self):
# list of uops
self.uops: List[UOp] = []
# global uop cache
self.saved_exprs: Dict[Tuple, UOp] = dict()
def __iter__(self): return iter(self.uops)
def vars(self) -> List[Variable]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_VAR]
def graph(self):
from tinygrad.features.graph import graph_uops
graph_uops(self.uops)
def print(self):
for u in self.uops:
print(f"{self.uops.index(u):4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} {str([self.uops.index(x) for x in u.vin]):32s} {u.arg}") # noqa: E501
def add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None, cachable=True, insert_before=None, simplify=True) -> UOp: # noqa: E501
if simplify:
if uop is UOps.PHI and len(vin) == 2: return vin[1] # a phi without loops is a noop
if uop is UOps.GEP and vin[0].uop is UOps.CONST: return self.add(UOps.CONST, dtype, arg=vin[0].arg, insert_before=insert_before)
if uop is UOps.CAST and all(x.uop is UOps.CONST for x in vin) and all_same([x.arg for x in vin]):
return self.add(UOps.CONST, dtype, arg=vin[0].arg, insert_before=insert_before)
if uop is UOps.ALU:
# rewrites. NOTE: the rewritten NEG op is still around...
if arg is BinaryOps.ADD and vin[1].uop is UOps.ALU and vin[1].arg is UnaryOps.NEG:
return self.add(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable, insert_before)
# constant folding
if arg is UnaryOps.NEG and vin[0].uop is UOps.CONST: return self.add(UOps.CONST, dtype, arg=-vin[0].arg, insert_before=insert_before)
if arg is TernaryOps.WHERE and vin[1] == vin[2]: return vin[1] # a conditional with the same results either way is a noop
if arg is BinaryOps.MUL and vin[0].uop is UOps.CONST and vin[1].uop is UOps.CONST and dtype is not None and dtypes.is_float(dtype):
return self.add(UOps.CONST, dtype, arg=vin[0].arg * vin[1].arg, insert_before=insert_before)
# zero folding
for x in [0,1]:
if arg is BinaryOps.ADD and vin[x].uop is UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
if arg is BinaryOps.MUL and vin[x].uop is UOps.CONST and vin[x].arg == 1.0: return vin[1-x]
if arg is BinaryOps.MUL and vin[x].uop is UOps.CONST and vin[x].arg == 0.0: return vin[x]
if arg is BinaryOps.SUB and vin[1].uop is UOps.CONST and vin[1].arg == 0.0: return vin[0]
if arg is BinaryOps.DIV and vin[1].uop is UOps.CONST and vin[1].arg == 1.0: return vin[0]
key = (uop, dtype, vin, arg)
if insert_before is None: insert_before = len(self.uops)
# check if the cached expr is valid with the given insert place.
if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and self.uops.index(expr) <= insert_before: return expr
ret = UOp(uop, dtype, vin, arg)
self.uops.insert(insert_before, ret)
if cachable: self.saved_exprs[key] = ret
return ret
def remove_childless(self):
while 1:
has_child: Set[UOp] = set()
for ru in self.uops:
for vu in ru.vin:
has_child.add(vu)
nu: List[UOp] = [x for x in self.uops if x in has_child or x.uop is UOps.STORE]
if len(nu) == len(self.uops): break
if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
self.uops = nu
# optional
def type_verify(self):
for u in self.uops:
uop, arg, vin, dtype = u.uop, u.arg, u.vin, u.dtype
if uop == UOps.ALU:
if arg in UnaryOps:
assert dtype == vin[0].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=}"
elif arg in (BinaryOps.CMPLT, BinaryOps.CMPEQ):
assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
assert vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
elif arg in BinaryOps:
assert dtype == vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
elif arg == TernaryOps.WHERE:
assert vin[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {vin[0].dtype=} != {dtypes.bool}"
assert dtype == vin[1].dtype == vin[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {vin[1].dtype=} != {vin[2].dtype=}"
def get_recursive_children(self, x:UOp) -> Set[UOp]:
deps = set([x])
ssize = 0
while ssize != len(deps):
ssize = len(deps)
for u in self.uops:
if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])):
deps.add(u)
return deps
def add_ends(self):
for u in self.uops:
if u.uop is UOps.LOOP:
# add END of loops after the last thing that (recursively) depends on them
self.add(UOps.ENDLOOP, None, (u,), cachable=False, insert_before=self.uops.index(sorted(list(self.get_recursive_children(u)), key=self.uops.index)[-1])+1) # noqa: E501
elif u.uop is UOps.IF:
# END any if statements at the end of the uops
self.add(UOps.ENDIF, None, (u,), cachable=False)
def fix_loop_scope(self, get_recursive_parents):
loop_stack: List[List[UOp]] = [[]]
# push uops upward out of loop if it does not depend on the loop
for u in self.uops:
if not loop_stack[-1]: loop_stack[-1].append(u)
elif u.uop == UOps.LOOP: loop_stack.append([u])
elif u.uop not in [UOps.CONST, UOps.ALU, UOps.CAST, UOps.LOAD]: loop_stack[-1].append(u)
else:
parents = get_recursive_parents(u, with_phi=True)
# don't push any local buffer because there might be STORE and BARRIER (not considered as parents) between DEFINE_LOCAL and here
if any(u.uop == UOps.DEFINE_LOCAL for u in parents): loop_stack[-1].append(u)
else:
for i in reversed(range(len(loop_stack))):
# check backwards and put the uop in the first encounter with some dependency
if any(x in parents for x in loop_stack[i]) or i == 0:
loop_stack[i].append(u)
break
self.uops = flatten(loop_stack)
def uoptimize(self):
# get PHI node loop scope, link anything using a DEFINE_ACC to the loop as a "parent"
acc_scope: DefaultDict[UOp, List[UOp]] = defaultdict(list)
for u in self.uops:
if u.uop == UOps.PHI:
acc_scope[u.vin[0]] += u.vin[2:]
# graph helper functions
@functools.lru_cache(None)
def get_recursive_parents(x:UOp, with_phi=False) -> Set[UOp]:
return set.union(set(x.vin), *[get_recursive_parents(p, with_phi) for p in x.vin], set(acc_scope[x]) if with_phi else set())
# fix loop scope, push uops upward out of loop if it does not depend on the loop
self.fix_loop_scope(get_recursive_parents)
# uops optimization
changed_something = True
while changed_something:
changed_something = False
for u in self.uops:
if u.uop is UOps.PHI and len(u.vin) == 3:
# if the parents of the PHI node don't have the LOOP in their parents, it can be folded
# TODO: ADD becomes a MUL, MAX can just become nothing
# NOTE: ADD -> MUL does not fold, this maintains original MULACC code path
if all(x.uop is not UOps.LOOP for x in get_recursive_parents(UOp(u.uop, u.dtype, u.vin[0:2], u.arg))) \
and u.vin[1].arg is BinaryOps.ADD and u.vin[1].vin[0].arg is not BinaryOps.MUL:
if DEBUG >= 4: print(f"removing PHI node {u}")
del self.saved_exprs[(u.uop, u.dtype, u.vin, u.arg)]
# NOTE: assuming u.vin[2].vin[1] and u.vin[2].vin[0] have the same dtype
loop_len = self.add(UOps.ALU, u.vin[2].vin[1].dtype, (u.vin[2].vin[1], u.vin[2].vin[0]), BinaryOps.SUB, insert_before=self.uops.index(u))
if loop_len.dtype != u.dtype: loop_len = self.add(UOps.CAST, u.dtype, (loop_len,), insert_before=self.uops.index(u))
new = self.add(UOps.ALU, u.dtype, (u.vin[1], loop_len,), BinaryOps.MUL, insert_before=self.uops.index(u))
# replace u with new
for v in self.uops: v.vin = tuple(new if x is u else x for x in v.vin)
self.uops.remove(u)
changed_something = True
# (recursively) remove childless uops
self.remove_childless()
# add UOps.END*
self.add_ends()
# verify the uop types
self.type_verify()
def flops_mem(self) -> Tuple[sint, sint]:
flops: sint = 0
mem: sint = 0
mults: sint = 1
mult_stack = []
for u in self.uops:
if u.uop is UOps.LOOP:
mult_stack.append(mults)
mults *= uop_alu_resolve(u.vin[1])
if u.uop is UOps.ENDLOOP:
mults = mult_stack.pop(-1)
if u.uop is UOps.ALU:
flops += mults
if u.uop is UOps.LOAD:
assert u.dtype is not None
mem += u.dtype.itemsize * mults
if u.uop is UOps.STORE:
assert u.vin[2].dtype is not None
mem += u.vin[2].dtype.itemsize * mults
if u.uop is UOps.WMMA:
if u.arg.startswith("__metal_wmma"): flops += 2*(8*8*8)//32 * mults
elif u.arg == "__hip_wmma_f16_f16" or u.arg == "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32": flops += 2*(16*16*16)//32 * mults
else: raise Exception("not implemented")
return flops, mem
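
UOpGraph now owns what used to be spread across free functions and Linearizer methods: the uop list, the saved_exprs dedup cache, add() with its simplifications, and the uoptimize/remove_childless/add_ends/type_verify/flops_mem passes. A small standalone sketch of the behavior visible in this file, using illustrative constants and assuming the module imports as written above:

from tinygrad.codegen.uops import UOps, UOpGraph, uop_alu_resolve
from tinygrad.dtype import dtypes
from tinygrad.ops import BinaryOps

g = UOpGraph()
a = g.add(UOps.CONST, dtypes.float, arg=2.0)
b = g.add(UOps.CONST, dtypes.float, arg=3.0)
# constant folding: a float CONST * CONST collapses to a single CONST
prod_uop = g.add(UOps.ALU, dtypes.float, (a, b), BinaryOps.MUL)
assert prod_uop.uop is UOps.CONST and prod_uop.arg == 6.0
# cachable adds are deduped through saved_exprs: the same key returns the same UOp
assert g.add(UOps.CONST, dtypes.float, arg=2.0) is a

# flops_mem() scales ALU counts by loop trip counts resolved with uop_alu_resolve
lo = g.add(UOps.CONST, dtypes.int32, arg=0)
hi = g.add(UOps.CONST, dtypes.int32, arg=16)
loop = g.add(UOps.LOOP, dtypes.int32, (lo, hi), cachable=False)
g.add(UOps.ALU, dtypes.float, (a, b), BinaryOps.ADD)   # one ALU inside the loop
g.add(UOps.ENDLOOP, None, (loop,), cachable=False)
assert uop_alu_resolve(hi) == 16
assert g.flops_mem() == (16, 0)                         # 1 ALU * 16 iterations, no LOAD/STORE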

View File

@@ -3,7 +3,6 @@ from collections import defaultdict
from typing import TYPE_CHECKING, Any, List, Optional, Dict, Tuple, ClassVar
import importlib, inspect, functools, pathlib, time, ctypes
from tinygrad.dtype import DType, ImageDType
from tinygrad.codegen.uops import UOps
from tinygrad.helpers import ansilen, DEBUG, getenv, colored, BEAM, NOOPT, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put
from tinygrad.helpers import prod
from tinygrad.shape.symbolic import Variable, sym_infer, sint
@@ -236,13 +235,11 @@ class Compiled:
assert self.compiler is not None, "compiler is required to run AST"
k.linearize()
info = get_lazyop_info(k.ast)
from tinygrad.codegen.uops import uops_flops_mem
ops, mem = uops_flops_mem(k.uops)
ops, mem = k.uops.flops_mem()
run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
# NOTE: we use min here to ignore the indexing FLOPS
ret = CompiledASTRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops), self, k.global_size, k.local_size,
[x.arg for x in k.uops if x.uop is UOps.DEFINE_VAR],
min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
ret = CompiledASTRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops.uops), self, k.global_size, k.local_size,
k.uops.vars(), min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
return ret
def get_linearizer(self, ast:LazyOp) -> Linearizer:
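
On the device side the estimators become calls on the graph: flops_mem() replaces the free function uops_flops_mem, vars() replaces the DEFINE_VAR list comprehension, and renderers still receive the raw uop list via k.uops.uops. A sketch of the path above, with dev and ast as placeholders for a real Compiled device and a scheduled LazyOp:

from tinygrad.helpers import to_function_name

def estimate_program(dev, ast):
    k = dev.get_linearizer(ast)              # dev: a Compiled device with a compiler set
    k.linearize()
    ops, mem = k.uops.flops_mem()            # was: uops_flops_mem(k.uops)
    variables = k.uops.vars()                # was: [x.arg for x in k.uops if x.uop is UOps.DEFINE_VAR]
    src = dev.compiler.render(to_function_name(k.name), k.uops.uops)
    return src, variables, ops, mem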

View File

@@ -2,7 +2,6 @@ from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
import itertools, functools, random, math, time, multiprocessing, traceback, signal
from tinygrad.device import Device, Compiled, Buffer, CompiledASTRunner, Compiler
from tinygrad.ops import MemBuffer
from tinygrad.codegen.uops import UOps
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
from tinygrad.dtype import ImageDType
from tinygrad.codegen.linearizer import Linearizer
@@ -49,8 +48,8 @@ def _time_program(variables:List[Variable], rdev:Compiled, lib:bytes, global_siz
def _compile_linearizer(compiler:Compiler, lin:Linearizer, name:Optional[str]=None) -> Tuple[bytes, Optional[List[int]], Optional[List[int]],
List[Variable]]:
lin.linearize()
src = compiler.render(name if name is not None else to_function_name(lin.name), lin.uops) # NOTE: these all have the same name for deduping
return compiler.compile(src), lin.global_size, lin.local_size, [x.arg for x in lin.uops if x.uop is UOps.DEFINE_VAR]
src = compiler.render(name if name is not None else to_function_name(lin.name), lin.uops.uops) # NOTE: these all have the same name for deduping
return compiler.compile(src), lin.global_size, lin.local_size, lin.uops.vars()
def _try_compile_linearized_w_idx(x, compiler:Compiler):
try: return (x[0], _compile_linearizer(compiler, x[1], "test"))