uops loop removal (#2262)

* remove the loop

* cleanups

* tests failing still

* global_loop_ctx wasn't needed

* replace_op is cleaner

* minor opt

* cast opt was wrong

* uop_num

* uop num was dumb

* tuplize_uops

* torch tests

* fix test_uops
George Hotz, 2023-11-10 15:24:47 -08:00 (committed by GitHub)
parent a753c8e071, commit 85d26ddc36
7 changed files with 97 additions and 63 deletions

.pre-commit-config.yaml

@@ -26,8 +26,8 @@ repos:
always_run: true
pass_filenames: false
- id: tests
- name: subset of (CPU) tests
- entry: env PYTHONPATH="." CPU=1 pytest test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py
+ name: subset of TORCH tests
+ entry: env PYTHONPATH="." TORCH=1 python3 -m pytest -n=4 test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py
language: system
always_run: true
pass_filenames: false

test/external/fuzz_linearizer.py

@@ -3,7 +3,7 @@ import numpy as np
from collections import Counter, defaultdict
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.codegen.linearizer import Linearizer
- from tinygrad.features.search import get_linearizer_actions, bufs_from_lin
+ from tinygrad.features.search import get_linearizer_actions, bufs_from_lin, tuplize_uops
from tinygrad.graph import print_tree
from tinygrad.helpers import ImageDType, prod, getenv
from tinygrad.ops import Device, Compiled, Interpreted
@@ -26,10 +26,7 @@ def fuzz_linearizer(lin: Linearizer):
print(lin.colored_shape())
rawbufs = bufs_from_lin(lin)
# NOTE: copied from beam_search
- def tuplize_uops(uops): return tuple([(x.uop, x.dtype, tuple(x.num for x in x.vin), x.arg) for x in uops])
seen_uops = {}
ground_truth = None
while 1:
if len(seen_uops) >= 20: break # enough for this kernel

test/test_ops.py

@@ -140,6 +140,12 @@ class TestOps(unittest.TestCase):
def test_sum_collapse(self):
helper_test_op([], lambda: torch.ones(256,256).sum(axis=1), lambda: Tensor.ones(256,256).sum(axis=1), forward_only=True)
def test_sum_collapse_neg(self):
helper_test_op([], lambda: (-torch.ones(3,3)).sum(axis=1), lambda: (-Tensor.ones(3,3)).sum(axis=1), forward_only=True)
+ def test_max_dont_collapse(self):
+ helper_test_op([], lambda: torch.ones(256,256).max(1)[0], lambda: Tensor.ones(256,256).max(1), forward_only=True)
def test_where(self):
helper_test_op(
[(100,)],

test/test_uops.py

@@ -13,7 +13,7 @@ def _uops_to_prg(uops):
runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime)
def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
- uops.append(UOp(uop, dtype, tuple(vin), arg, len(uops)))
+ uops.append(UOp(uop, dtype, tuple(vin), arg))
return uops[-1]
def _test_single_value(vals, op, dtype):

tinygrad/codegen/linearizer.py

@@ -5,7 +5,7 @@ from collections import defaultdict
from enum import Enum, auto
from dataclasses import dataclass
- from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, getenv
+ from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, getenv, all_same
from tinygrad.ops import LazyOp, UnaryOps, ConstBuffer, MemBuffer, BufferOps
from tinygrad.ops import ReduceOps, BinaryOps, TernaryOps
from tinygrad.shape.shapetracker import ShapeTracker
@@ -21,19 +21,13 @@ class UOps(Enum):
LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto(); PHI = auto() # noqa: E702
ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702
- @dataclass
+ @dataclass(eq=False)
class UOp:
uop: UOps
dtype: Optional[DType]
vin: Tuple[UOp, ...]
arg: Any
def __repr__(self): return f"{self.num:4d} {str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.num for x in self.vin]):32s} {self.arg}"
#def __repr__(self): return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str(self.vin):32s} {self.arg}"
# UOps are unique
num: int
def __hash__(self): return self.num
def __eq__(self, x): return self.num == x.num
def __repr__(self): return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"
def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
local_idxs = loop_local_idxs = [Variable(f"{prefix}{start_dim+i}", 0, s-1) for i,s in enumerate(local_dims[0:maxdim-1] + (prod(local_dims[maxdim-1:]),) if len(local_dims) > maxdim else local_dims)]
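With num gone, object identity is what makes UOps unique: @dataclass(eq=False) keeps Python's default identity-based __eq__ and __hash__, and ordering comes from a uop's position in the uops list. A minimal sketch of that behaviour, using a hypothetical Node stand-in rather than the real UOp:

from dataclasses import dataclass
from typing import Any, Optional, Tuple

@dataclass(eq=False)
class Node:  # hypothetical stand-in, just to show what eq=False buys
  uop: str
  dtype: Optional[str]
  vin: Tuple["Node", ...]
  arg: Any

a = Node("CONST", "int32", (), 1)
b = Node("CONST", "int32", (), 1)
assert a != b                # compared by identity, not by field values
assert len({a, b}) == 2      # hashable by id(), so sets/dicts of uops still work
uops = [a, b]
assert uops.index(b) == 1    # position in the uops list replaces the old num field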
@@ -52,7 +46,7 @@ class Linearizer(Kernel):
return self.uop(UOps.ALU, dtype, (a, render_b), op)
# NOTE: the consts have to be cached for deduping of downstream uops to work
- def const(self, b:Union[int,float], dtype=dtypes.int32) -> UOp: return self.uop(UOps.CONST, dtype, tuple(), b)
+ def const(self, b:Union[int,float], dtype=dtypes.int32, insert_before=None) -> UOp: return self.uop(UOps.CONST, dtype, tuple(), b, insert_before=insert_before)
render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b),
MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
@@ -215,7 +209,6 @@ class Linearizer(Kernel):
# set global/local size
self.global_size: Optional[List[int]] = None
self.local_size: Optional[List[int]] = None
- global_loop_ctx: Tuple[UOp, ...] = tuple()
if self.dont_use_locals:
self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)})
@@ -226,7 +219,7 @@ class Linearizer(Kernel):
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
else:
- global_loop_ctx = render_loop(loop_global_idxs+loop_local_idxs)
+ render_loop(loop_global_idxs+loop_local_idxs)
# parse AST
loaded_buffers = {}
@@ -302,7 +295,7 @@ class Linearizer(Kernel):
self.uop(UOps.CAST, dtypes._float8, tuple(op3)))
ret = self.uop(UOps.WMMA, dtypes._float2 if wmma_sz[2] == 2 else dtypes._float8, ops, (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,))
for z in range(cast(DType, ret.dtype).sz):
- acc[i+z] = self.uop(UOps.PHI, dtypes.float, (op3[z], self.uop(UOps.GEP, dtypes.float, (ret,), z)) + global_loop_ctx + loop_ctx)
+ acc[i+z] = self.uop(UOps.PHI, dtypes.float, (op3[z], self.uop(UOps.GEP, dtypes.float, (ret,), z)) + loop_ctx)
i += wmma_sz[2]
else:
if locals_to_store:
@@ -314,7 +307,7 @@ class Linearizer(Kernel):
loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs})
# run early AST (with reduce)
- self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=global_loop_ctx + loop_ctx)
+ self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx)
# end the reduce loop
self.load_cache.clear()
@@ -365,11 +358,54 @@ class Linearizer(Kernel):
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer})
# run late AST
- val = self.ast_parse(self.ast, acc, None, loaded_buffers, loop_ctx=global_loop_ctx)
+ val = self.ast_parse(self.ast, acc, None, loaded_buffers)
# store
self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
+ # graph helper functions
+ def get_recursive_parents(x:List[UOp]) -> List[UOp]:
+ ret: Set[UOp] = set()
+ this_round: Set[UOp] = set(x)
+ while len(this_round):
+ ret = ret.union(this_round)
+ next_round: Set[UOp] = set()
+ for r in this_round: next_round = next_round.union(set(r.vin))
+ this_round = next_round
+ return list(ret)
+ def get_recursive_children(x:UOp) -> List[UOp]:
+ deps = set([x])
+ ssize = 0
+ while ssize != len(deps):
+ ssize = len(deps)
+ for u in self.uops:
+ if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])):
+ deps.add(u)
+ return sorted(list(deps), key=self.uops.index) # get the last one
+ def replace_op(old:UOp, new:UOp):
+ for u in self.uops:
+ u.vin = tuple(new if x is old else x for x in u.vin)
+ self.uops.remove(old)
+ # uops optimization
+ changed_something = True
+ while changed_something:
+ changed_something = False
+ for u in self.uops:
+ if u.uop == UOps.PHI and len(u.vin) == 3:
+ # if the parents of the PHI node don't have the LOOP in their parents, it can be folded
+ # TODO: ADD becomes a MUL, MAX can just become nothing
+ if all(x.uop != UOps.LOOP for x in get_recursive_parents(list(u.vin[0:2]))) and u.vin[1].arg == BinaryOps.ADD:
+ if DEBUG >= 4: print(f"removing PHI node {u}")
+ del self.saved_exprs[(u.uop, u.dtype, u.vin, u.arg)]
+ # NOTE: assuming u.vin[2].vin[1] and u.vin[2].vin[0] have the same dtype
+ loop_len = self.uop(UOps.ALU, u.vin[2].vin[1].dtype, (u.vin[2].vin[1], u.vin[2].vin[0]), BinaryOps.SUB, insert_before=self.uops.index(u))
+ if loop_len.dtype != u.dtype: loop_len = self.uop(UOps.CAST, u.dtype, (loop_len,), insert_before=self.uops.index(u))
+ replace_op(u, self.uop(UOps.ALU, u.dtype, (u.vin[1], loop_len,), BinaryOps.MUL, insert_before=self.uops.index(u)))
+ changed_something = True
# (recursively) remove childless uops
# NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that
UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER, UOps.DEFINE_GLOBAL}
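The new optimization pass is the core of this commit: a PHI accumulator whose update never reads the loop variable can be folded away, and the whole loop becomes a single multiply by the trip count (the SUB of the loop bounds inserted above). A plain-Python sketch of the transformation, not the uop machinery itself:

# sketch of what removing the loop means for a sum-of-a-constant reduction:
# the accumulator update (acc = acc + c) never uses the loop variable, so the
# PHI can be replaced by c multiplied by the loop length
def summed_in_loop(c: float, start: int, end: int) -> float:
  acc = 0.0
  for _ in range(start, end):   # LOOP with a PHI-carried accumulator
    acc = acc + c
  return acc

def loop_removed(c: float, start: int, end: int) -> float:
  loop_len = end - start        # the ALU SUB of the loop bounds from the diff
  return c * loop_len           # ADD-accumulation becomes a MUL; no loop is emitted

assert summed_in_loop(3.0, 0, 256) == loop_removed(3.0, 0, 256) == 768.0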
@@ -382,32 +418,21 @@ class Linearizer(Kernel):
if len(nu) == len(self.uops): break
if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
self.uops = nu
del nu
- def get_recursive_deps(x:UOp) -> List[UOp]:
- deps = set([x])
- ssize = 0
- while ssize != len(deps):
- ssize = len(deps)
- for u in self.uops:
- if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])):
- deps.add(u)
- return sorted(list(deps), key=lambda x: x.num)
- # add END of loops after the last thing that (recursively) depends on them
- # and END any if statements
+ # add UOps.END
for u in self.uops:
if u.uop == UOps.LOOP:
- last_phi = self.uops.index(get_recursive_deps(u)[-1])
- at_end = self.uops[last_phi+1:]
- self.uops = self.uops[:last_phi+1]
- self.uop(UOps.END, None, (u,), cachable=False)
- self.uops += at_end
+ # add END of loops after the last thing that (recursively) depends on them
+ self.uop(UOps.END, None, (u,), cachable=False, insert_before=self.uops.index(get_recursive_children(u)[-1])+1)
elif u.uop == UOps.IF:
+ # END any if statements at the end of the uops
self.uop(UOps.END, None, (u,), cachable=False)
# maybe graph the uops
if DEBUG >= 5:
- for u in self.uops: print(u)
+ for u in self.uops:
+ print(f"{self.uops.index(u):4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} {str([self.uops.index(x) for x in u.vin]):32s} {u.arg}")
if getenv("GRAPHUOPS"):
from tinygrad.graph import graph_uops
graph_uops(self.uops)
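END placement also gets simpler: instead of splicing the uop list around the loop tail, the pass collects everything that (recursively) consumes the LOOP, ignoring PHI back-edges, and inserts the END right after the last such uop via insert_before. A standalone sketch of that fixed-point walk, using a hypothetical Op stand-in:

from typing import List, Set, Tuple

class Op:
  def __init__(self, name: str, vin: Tuple["Op", ...] = ()):
    self.name, self.vin = name, vin

def recursive_children(uops: List[Op], root: Op) -> List[Op]:
  deps: Set[Op] = {root}
  size = 0
  while size != len(deps):             # fixed point: stop when nothing new is absorbed
    size = len(deps)
    for u in uops:
      if deps.intersection(x for x in u.vin if x.name != "PHI"):
        deps.add(u)
  return sorted(deps, key=uops.index)  # program order; END goes right after [-1]

loop = Op("LOOP")
load = Op("LOAD", (loop,))
alu = Op("ALU", (load,))
store = Op("STORE", (alu,))
other = Op("CONST")                    # independent of the loop
uops = [loop, load, alu, store, other]
assert recursive_children(uops, loop)[-1] is store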
@@ -419,27 +444,32 @@ class Linearizer(Kernel):
self.applied_opts_cache = self.applied_opts[:]
return self
- def uop(self, uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None, cachable=True) -> UOp:
+ def uop(self, uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None, cachable=True, insert_before=None, simplify=True) -> UOp:
key = (uop, dtype, vin, arg)
- if uop == UOps.PHI and len(vin) == 2 and vin[0] == vin[1]: return vin[0] # self phi is noop
- if uop == UOps.GEP and vin[0].uop == UOps.CONST: return self.const(vin[0].arg, dtype)
- if uop == UOps.ALU:
- # rewrites. NOTE: the rewritten NEG op is still around...
- if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG: return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable=cachable)
- # constant folding
- if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.const(-vin[0].arg, dtype)
- # zero folding
- for x in [0,1]:
- if arg == BinaryOps.ADD and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
- if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 1.0: return vin[1-x]
- if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[x]
- if arg == BinaryOps.SUB and vin[1].uop == UOps.CONST and vin[1].arg == 0.0: return vin[0]
- if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]
+ if simplify:
+ if uop == UOps.PHI and len(vin) == 2: return vin[1] # a phi without loops is a noop
+ if uop == UOps.GEP and vin[0].uop == UOps.CONST: return self.const(vin[0].arg, dtype, insert_before)
+ if uop == UOps.CAST and all(x.uop == UOps.CONST for x in vin) and all_same([x.arg for x in vin]): return self.const(vin[0].arg, dtype, insert_before)
+ if uop == UOps.ALU:
+ # rewrites. NOTE: the rewritten NEG op is still around...
+ if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG: return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable=cachable, insert_before=insert_before)
+ # constant folding
+ if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.const(-vin[0].arg, dtype, insert_before)
+ # zero folding
+ for x in [0,1]:
+ if arg == BinaryOps.ADD and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
+ if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 1.0: return vin[1-x]
+ if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[x]
+ if arg == BinaryOps.SUB and vin[1].uop == UOps.CONST and vin[1].arg == 0.0: return vin[0]
+ if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]
if cachable and key in self.saved_exprs: return self.saved_exprs[key]
- self.uops.append(UOp(uop, dtype, vin, arg, len(self.uops)))
- #if DEBUG >= 5: print(self.uops[-1])
- if cachable: self.saved_exprs[key] = self.uops[-1]
- return self.uops[-1]
+ ret = UOp(uop, dtype, vin, arg)
+ if insert_before is not None:
+ self.uops.insert(insert_before, ret)
+ else:
+ self.uops.append(ret)
+ if cachable: self.saved_exprs[key] = ret
+ return ret
def ast_parse(self, x, acc, offs, loaded_buffers, do_reduce=False, loop_ctx=tuple()) -> List[UOp]:
if x.__class__ is not LazyOp: return loaded_buffers[x] # for LOCAL_BUFFER
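The rewritten uop() separates three concerns: optional algebraic simplification (guarded by simplify), deduplication through saved_exprs keyed on (uop, dtype, vin, arg), and placement, where insert_before splices the new uop into the list instead of appending. A compact sketch of the dedup-and-placement part; MiniLinearizer is a hypothetical illustration, not the real class:

from typing import Any, Dict, List, Optional, Tuple

class MiniLinearizer:
  def __init__(self):
    self.uops: List[Tuple] = []
    self.saved_exprs: Dict[Tuple, Tuple] = {}

  def uop(self, op: str, dtype: Optional[str], vin: Tuple, arg: Any = None,
          cachable: bool = True, insert_before: Optional[int] = None) -> Tuple:
    key = (op, dtype, vin, arg)
    if cachable and key in self.saved_exprs: return self.saved_exprs[key]  # dedup hit
    ret = key                            # the real code builds a UOp dataclass here
    if insert_before is not None:
      self.uops.insert(insert_before, ret)   # splice into the middle of the program
    else:
      self.uops.append(ret)
    if cachable: self.saved_exprs[key] = ret
    return ret

lin = MiniLinearizer()
a = lin.uop("CONST", "int32", (), 2)
b = lin.uop("CONST", "int32", (), 2)                         # deduped: same object as a
assert a is b and len(lin.uops) == 1
lin.uop("END", None, (a,), cachable=False, insert_before=0)  # spliced at the front
assert lin.uops[0][0] == "END"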

tinygrad/features/search.py

@@ -3,7 +3,7 @@ import itertools, random
from tinygrad.lazy import vars_from_ast
from tinygrad.ops import Device, Compiled, MemBuffer
from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context
- from tinygrad.codegen.linearizer import Linearizer
+ from tinygrad.codegen.linearizer import Linearizer, UOp
from tinygrad.runtime.lib import RawBuffer
from collections import defaultdict
from tinygrad.tensor import Tensor
@@ -97,6 +97,8 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
pass
return acted_lins
+ def tuplize_uops(uops:List[UOp]) -> Tuple: return tuple([(x.uop, x.dtype, tuple(uops.index(x) for x in x.vin), x.arg) for x in uops])
def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size}
if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
@@ -108,7 +110,6 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
beam: List[Tuple[Linearizer, float]] = [(lin, time_linearizer(lin, rawbufs, allow_test_size=allow_test_size))]
# NOTE: real uops use a weird compare method that's only valid inside a linearizer
- def tuplize_uops(uops): return tuple([(x.uop, x.dtype, tuple(x.num for x in x.vin), x.arg) for x in uops])
seen_uops = {tuplize_uops(lin.linearize().uops): tuple(lin.applied_opts)}
while 1:
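tuplize_uops is now defined once in tinygrad/features/search.py and shared with fuzz_linearizer, replacing the per-file copies; since UOps no longer carry a num, inputs are encoded by their index in the uop list, which turns a whole uop program into a hashable key for seen_uops. A small sketch with a hypothetical FakeUOp stand-in:

from typing import Any, List, NamedTuple, Optional, Tuple

class FakeUOp(NamedTuple):   # hypothetical stand-in for the real UOp dataclass
  uop: str
  dtype: Optional[str]
  vin: Tuple["FakeUOp", ...]
  arg: Any

def tuplize_uops(uops: List[FakeUOp]) -> Tuple:
  # same shape as the shared helper: vin is encoded by list position, not by num
  return tuple((x.uop, x.dtype, tuple(uops.index(v) for v in x.vin), x.arg) for x in uops)

a = FakeUOp("CONST", "int32", (), 3)
b = FakeUOp("ALU", "int32", (a,), "NEG")
seen_uops = {}
key = tuplize_uops([a, b])
seen_uops.setdefault(key, "first kernel that linearized to this uop program")
assert key in seen_uops      # identical programs from other kernels would be skipped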

tinygrad/graph.py

@@ -114,8 +114,8 @@ def graph_uops(uops):
G = nx.DiGraph()
for u in uops:
if u.uop == UOps.END: continue
- G.add_node(u.num, label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff"))
- for v in u.vin: G.add_edge(v.num, u.num)
+ G.add_node(uops.index(u), label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff"))
+ for v in u.vin: G.add_edge(uops.index(v), uops.index(u))
GRAPHPATH = "/tmp/uops"
nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot')
os.system(f'dot -Grankdir=LR -Tsvg {GRAPHPATH}.dot -o {GRAPHPATH}.svg')