mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
Refactor to UOpGraph class (#3566)
* Refactor to UOpGraph class
* fix test
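The change below moves the linearizer's raw uop list and its uop() builder method onto a dedicated UOpGraph object. A minimal sketch of the call-site migration, using only names that appear in this diff:

    # before: the Linearizer owned the list and the builder method
    self.uops: List[UOp] = []
    out = self.uop(UOps.CONST, dtypes.int32, (), 0)

    # after: the list, the dedup cache, and the builder live on a UOpGraph
    self.uops: UOpGraph = UOpGraph()
    out = self.uops.add(UOps.CONST, dtypes.int32, (), 0)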
@@ -13,6 +13,7 @@ from tinygrad.features.jit import CacheCollector
 from tinygrad.realize import create_schedule, run_schedule
 from tinygrad.helpers import prod, Context
 from tinygrad.dtype import DType, dtypes
+from tinygrad.codegen.uops import UOpGraph

 @unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "linearizer is only for compiled backends")
 class TestLinearizer(unittest.TestCase):
@@ -54,8 +55,8 @@ class TestLinearizer(unittest.TestCase):
     lin = Linearizer(ast)
     lin.linearize()

-    a_bufs = [u.uop for u in lin.uops[-2].vin[0].vin]
-    b_bufs = [u.uop for u in lin.uops[-2].vin[1].vin]
+    a_bufs = [u.uop for u in lin.uops.uops[-2].vin[0].vin]
+    b_bufs = [u.uop for u in lin.uops.uops[-2].vin[1].vin]

     assert a_bufs == [UOps.LOAD, UOps.CONST]
     assert b_bufs == [UOps.CONST, UOps.CONST]
@@ -199,8 +200,8 @@ class TestLinearizer(unittest.TestCase):
                    MemBuffer(0, dtypes.float, ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),))))
       lin = Linearizer(ast=ast) # this is a dummy ast

-      lin.uops = []
-      return lin.uop(uop, dtype, vin, arg, cachable=False)
+      lin.uops = UOpGraph()
+      return lin.uops.add(uop, dtype, vin, arg, cachable=False)

     c0 = UOp(UOps.CONST, dtypes.float, vin=(), arg=0.0)
     assert helper_test_simplify(UOps.ALU, dtypes.float, vin=(UOp(UOps.CONST, dtypes.bool, vin=(), arg=True), c0, c0), arg=TernaryOps.WHERE) == c0
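Callers that indexed the old list directly now go one level deeper, through the graph's .uops attribute (hence lin.uops.uops in the tests above), while iteration keeps working because UOpGraph defines __iter__. A small sketch of both access paths, assuming a linearized kernel lin:

    lin.linearize()
    last = lin.uops.uops[-2]                 # raw list access moved one level down
    assert list(lin.uops) == lin.uops.uops   # UOpGraph.__iter__ yields the same uops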
@@ -1,17 +1,17 @@
 from __future__ import annotations
-from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, Final, Set, Iterator, Sequence
+from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, Final, Iterator, Sequence
 import itertools, math, functools
 from collections import defaultdict

 from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
-from tinygrad.helpers import colored, DEBUG, prod, getenv, all_same, to_function_name
+from tinygrad.helpers import colored, DEBUG, prod, getenv, to_function_name
 from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
 from tinygrad.codegen.kernel import LocalBuffer, Kernel
 from tinygrad.features.image import to_image_idx

-from tinygrad.codegen.uops import UOps, UOp, get_recursive_children, remove_childless_uops, fix_loop_scope, uops_type_verify
+from tinygrad.codegen.uops import UOps, UOp, UOpGraph

 def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
   local_idxs = loop_local_idxs = [Variable(f"{prefix}{start_dim+i}", 0, s-1) for i,s in enumerate(local_dims[0:maxdim-1] + (prod(local_dims[maxdim-1:]),) if len(local_dims) > maxdim else local_dims)]  # noqa: E501
@@ -41,13 +41,13 @@ def expand_node(node:Node, idxs:Optional[Tuple[Union[Variable, NumNode], ...]]=N
 class Linearizer(Kernel):
   def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op, dtype=dtypes.int32):
     render_b:UOp = cast(UOp, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
-    return self.uop(UOps.ALU, dtype, (a, render_b), op)
+    return self.uops.add(UOps.ALU, dtype, (a, render_b), op)

   # NOTE: the consts have to be cached for deduping of downstream uops to work
   def const(self, b:Union[int,float], dtype=dtypes.int32, insert_before=None) -> UOp:
-    return self.uop(UOps.CONST, dtype, tuple(), b, insert_before=insert_before)
+    return self.uops.add(UOps.CONST, dtype, tuple(), b, insert_before=insert_before)

-  def cast(self, val: UOp, dtype) -> UOp: return self.uop(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val
+  def cast(self, val: UOp, dtype) -> UOp: return self.uops.add(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val

   def get_reduce_acc(self, reduceop:LazyOp):
     info = get_lazyop_info(reduceop)
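The NOTE above is why const() funnels through the cached add() path: UOpGraph.add keys cachable uops on (uop, dtype, vin, arg) in saved_exprs, so an identical constant comes back as the very same UOp object and everything built on top of it dedups too. A rough illustration, assuming the UOpGraph from this diff:

    g = UOpGraph()
    c1 = g.add(UOps.CONST, dtypes.int32, tuple(), 4)
    c2 = g.add(UOps.CONST, dtypes.int32, tuple(), 4)
    assert c1 is c2 and len(g.uops) == 1  # one cached uop, not two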
@@ -98,35 +98,36 @@ class Linearizer(Kernel):
       key = f"{acc}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}"  # noqa: E501
       if key not in self.load_cache:
         if acc is not None:
-          self.load_cache[key] = self.uop(UOps.DEFINE_ACC, localtype, (), this_const, cachable=False)
+          self.load_cache[key] = self.uops.add(UOps.DEFINE_ACC, localtype, (), this_const, cachable=False)
         elif this_const is not None:
           self.load_cache[key] = self.const(this_const, localtype)
           if valid.min == 0 and valid.max == 1:
             valid_rendered = valid.render(self.render_ops, self)
-            self.load_cache[key] = self.uop(UOps.ALU, localtype, (valid_rendered, self.load_cache[key], self.const(invalid_value, localtype)), TernaryOps.WHERE)  # noqa: E501
+            self.load_cache[key] = self.uops.add(UOps.ALU, localtype, (valid_rendered, self.load_cache[key], self.const(invalid_value, localtype)), TernaryOps.WHERE)  # noqa: E501
         elif isinstance(buf.dtype, ImageDType):
           buf_uop = self.buf_uops[i]
           assert buf_uop is not None, f"buffer {i} wasn't UOped"
           image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
-          rendered_idx = self.uop(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx))
+          rendered_idx = self.uops.add(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx))
           valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, buf.dtype.base.vec(4))) if valid.min == 0 else tuple()
-          self.load_cache[key] = self.uop(UOps.LOAD, buf.dtype.base.vec(4), (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
+          self.load_cache[key] = self.uops.add(UOps.LOAD, buf.dtype.base.vec(4),
+                                               (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
           if localtype == localtype.scalar():
             idx_small = idx%4
             res = idx_small.render(self.render_ops, self)
-            out = self.uop(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max)
+            out = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max)
             for ix in range(idx_small.max, idx_small.min, -1):
-              rvv = self.uop(UOps.GEP, localtype, (self.load_cache[key],), ix-1)
-              sel = self.uop(UOps.ALU, dtypes.bool, (res, self.const(ix)), BinaryOps.CMPLT)
-              out = self.uop(UOps.ALU, localtype, (sel, rvv, out), TernaryOps.WHERE)
+              rvv = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), ix-1)
+              sel = self.uops.add(UOps.ALU, dtypes.bool, (res, self.const(ix)), BinaryOps.CMPLT)
+              out = self.uops.add(UOps.ALU, localtype, (sel, rvv, out), TernaryOps.WHERE)
             self.load_cache[key] = out
         else:
           buf_uop = self.buf_uops[i]
           assert buf_uop is not None, f"buffer {i} wasn't UOped"
           rendered_idx = idx.render(self.render_ops, self)
           valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, localtype)) if valid.min == 0 else tuple()
-          self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
-      ret.append(self.uop(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
+          self.load_cache[key] = self.uops.add(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
+      ret.append(self.uops.add(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
     return ret

   def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> List[UOp]:
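The GEP/CMPLT/WHERE chain above handles scalar loads from an image: the hardware always fetches a float4, so the wanted component is picked by the runtime index idx%4. The selection logic, restated as plain Python with made-up values:

    vec = [1.0, 2.0, 3.0, 4.0]  # the float4 that UOps.LOAD produced
    res = 2                     # idx % 4, only known at runtime
    out = vec[3]                # start from the highest component (GEP idx_small.max)
    for ix in range(3, 0, -1):  # each step is one CMPLT + WHERE pair
      out = vec[ix-1] if res < ix else out
    assert out == vec[res]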
@@ -149,7 +150,7 @@ class Linearizer(Kernel):
         amt = len(grouped)
         idx, valid = self.sts[i].expr_idxs(k)
         assert idx == ((idx//amt)*amt), "float4 stores are always aligned"
-        store_offset_new[k] = self.uop(UOps.CAST, buf.dtype.vec(amt), tuple(grouped))
+        store_offset_new[k] = self.uops.add(UOps.CAST, buf.dtype.vec(amt), tuple(grouped))
       store_offset = store_offset_new

     stores = []
@@ -157,11 +158,11 @@ class Linearizer(Kernel):
       idx, valid = self.sts[i].expr_idxs(_idx)
       if isinstance(buf.dtype, ImageDType):
         image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
-        rendered_idx = self.uop(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx))
+        rendered_idx = self.uops.add(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx))
       else:
         rendered_idx = idx.render(self.render_ops, self)
-      if valid.min == 1: stores.append(self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var)))
-      else: stores.append(self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self))))
+      if valid.min == 1: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var)))
+      else: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self))))
     return stores

   kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
@@ -185,30 +186,31 @@ class Linearizer(Kernel):
     if self.opts.global_max and self.opts.local_max: self.limit_dims_to_max(self.opts.global_max, self.opts.local_max)

     # uops
-    self.uops: List[UOp] = []
+    self.uops:UOpGraph = UOpGraph()
     self.buf_uops: List[Optional[UOp]] = [None]*len(self.bufs)
     self.loop_uops: Dict[str, UOp] = {}

     # add global buffers
     for i,buf in enumerate(self.bufs):
       if isinstance(buf, MemBuffer):
-        self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL,
-                                    buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
-                                    (buf.idx, f"data{buf.idx}"))
+        self.buf_uops[i] = self.uops.add(UOps.DEFINE_GLOBAL,
+                                         buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
+                                         (buf.idx, f"data{buf.idx}"))
     # add var vals
     for i,var in enumerate(self.ast.vars()):
       assert var.expr is not None
-      self.loop_uops[var.expr] = self.uop(UOps.DEFINE_VAR, dtypes.int32, (), var)
+      self.loop_uops[var.expr] = self.uops.add(UOps.DEFINE_VAR, dtypes.int32, (), var)
    # define local buffers
     for lb in self.local_alias.values():
-      self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size))
+      self.buf_uops[self.bufs.index(lb)] = self.uops.add(UOps.DEFINE_LOCAL,
+                                                         PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size))
     # add a local buffer for multistage reduce. # TODO: use local alias
     if self.group_for_reduces:
       # TODO: the strides of this can be controlled
       self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))  # noqa: E501
       temp_dtype = self.get_base_dtype(get_lazyop_info(self.reduceop).dtype)
       self.bufs.append(LocalBuffer("temp", self.sts[-1].size, temp_dtype))
-      self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), ("temp", self.sts[-1].size)))
+      self.buf_uops.append(self.uops.add(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), ("temp", self.sts[-1].size)))

     # kernel name (before late upcast)
     self.name = ("r_" if self.reduceop else "E_") + colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
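Each MemBuffer becomes one DEFINE_GLOBAL uop whose arg carries the buffer index and its generated parameter name. A standalone sketch of just that setup step, assuming the imports used in this diff:

    g = UOpGraph()
    buf_uops = [g.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float32), (), (i, f"data{i}")) for i in range(2)]
    assert [u.arg for u in g.uops] == [(0, "data0"), (1, "data1")]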
@@ -226,7 +228,7 @@ class Linearizer(Kernel):

     # global and local loops
     def render_loop(xx:List[Variable]) -> Tuple[UOp, ...]:
-      new_loops = {x.expr:self.uop(UOps.LOOP, dtypes.int32, (
+      new_loops = {x.expr:self.uops.add(UOps.LOOP, dtypes.int32, (
         self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
         self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None}  # noqa: E501
       self.loop_uops.update(new_loops)
@@ -237,11 +239,11 @@ class Linearizer(Kernel):
     self.local_size: Optional[List[int]] = None
     if self.dont_use_locals:
       self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
-      self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)})  # noqa: E501
+      self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)})  # noqa: E501
     elif self.opts.has_local:
       self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs][::-1]
-      self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})  # noqa: E501
-      self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})  # noqa: E501
+      self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})  # noqa: E501
+      self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})  # noqa: E501
     else:
       render_loop(loop_global_idxs+loop_local_idxs)
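Both branches derive the launch sizes the same way: each loop Variable ranges over 0..size-1, so x.max+1 recovers the size, and the list is reversed to match the SPECIAL index len(loop_global_idxs)-1-i. A tiny check of that convention, with a hypothetical (2,3,4) launch:

    from tinygrad.shape.symbolic import Variable
    loop_global_idxs = [Variable(f"gidx{i}", 0, s-1) for i,s in enumerate((2,3,4))]
    assert [x.max+1 for x in loop_global_idxs][::-1] == [4, 3, 2]  # becomes self.global_size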
@@ -311,12 +313,12 @@ class Linearizer(Kernel):
          upcasts = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]]
          for iter in [x[::-1] for x in [x for x in itertools.product(*[x for x in [range(sz) for _,sz in upcasts[0]][::-1]])]]:
            offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(iter, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
-            ops = (self.uop(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
-                   self.uop(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
-                   self.uop(UOps.CAST, tc.dtype_out.vec(wmma_sz[2]), tuple(op3:=acc[offs[2]:offs[2]+wmma_sz[2]])))
-            ret = self.uop(UOps.WMMA, tc.dtype_out.vec(wmma_sz[2]), ops, tc.wmma_func)
+            ops = (self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
+                   self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
+                   self.uops.add(UOps.CAST, tc.dtype_out.vec(wmma_sz[2]), tuple(op3:=acc[offs[2]:offs[2]+wmma_sz[2]])))
+            ret = self.uops.add(UOps.WMMA, tc.dtype_out.vec(wmma_sz[2]), ops, tc.wmma_func)
            for z in range(wmma_sz[2]):
-              acc[offs[2]+z] = self.uop(UOps.PHI, tc.dtype_out, (op3[z], self.uop(UOps.GEP, tc.dtype_out, (ret,), z)) + loop_ctx)
+              acc[offs[2]+z] = self.uops.add(UOps.PHI, tc.dtype_out, (op3[z], self.uops.add(UOps.GEP, tc.dtype_out, (ret,), z)) + loop_ctx)
        else:
          assert not locals_to_store, "storing locals isn't supported here"

@@ -333,12 +335,12 @@ class Linearizer(Kernel):
     if self.group_for_reduces:
       fake_global_idxs = [x*0 for x in global_idxs]
       stores = self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc)  # store accumulators
-      barrier = self.uop(UOps.BARRIER, None, tuple(stores), cachable=False)
+      barrier = self.uops.add(UOps.BARRIER, None, tuple(stores), cachable=False)
       if self.opts.has_local:
         fake_idxs = [NumNode(0)]*len(self.sts[-1].shape)
         fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
         if_cond: UOp = (self.sts[-1].expr_idxs(fake_idxs)[0]<1).render(self.render_ops, self)
-        barrier = self.uop(UOps.IF, None, (if_cond, barrier), cachable=False)
+        barrier = self.uops.add(UOps.IF, None, (if_cond, barrier), cachable=False)

       # create new late reduce local loops and replace local_idxs that have been used
       end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+self.group_for_reduces)]  # noqa: E501
@@ -385,15 +387,11 @@ class Linearizer(Kernel):
     self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)

     # optimize the uops
-    self.uoptimize()
+    self.uops.uoptimize()

     # maybe graph the uops
-    if DEBUG >= 5:
-      for u in self.uops:
-        print(f"{self.uops.index(u):4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} {str([self.uops.index(x) for x in u.vin]):32s} {u.arg}")  # noqa: E501
-    if getenv("GRAPHUOPS"):
-      from tinygrad.features.graph import graph_uops
-      graph_uops(self.uops)
+    if DEBUG >= 5: self.uops.print()
+    if getenv("GRAPHUOPS"): self.uops.graph()

     # restore backups
     self.sts, self.group_for_reduces, self.upcasted = sts_backup, gfr_backup, upc_backup
@@ -402,96 +400,12 @@ class Linearizer(Kernel):
     self.applied_opts_cache = self.applied_opts[:]
     return self

-  def uoptimize(self):
-    # get PHI node loop scope, link anything using a DEFINE_ACC to the loop as a "parent"
-    acc_scope: DefaultDict[UOp, List[UOp]] = defaultdict(list)
-    for u in self.uops:
-      if u.uop == UOps.PHI:
-        acc_scope[u.vin[0]] += u.vin[2:]
-
-    # graph helper functions
-    @functools.lru_cache(None)
-    def get_recursive_parents(x:UOp, with_phi=False) -> Set[UOp]:
-      return set.union(set(x.vin), *[get_recursive_parents(p, with_phi) for p in x.vin], set(acc_scope[x]) if with_phi else set())
-
-    # fix loop scope, push uops upward out of loop if it does not depend on the loop
-    self.uops = fix_loop_scope(get_recursive_parents, self.uops)
-
-    # uops optimization
-    changed_something = True
-    while changed_something:
-      changed_something = False
-      for u in self.uops:
-        if u.uop is UOps.PHI and len(u.vin) == 3:
-          # if the parents of the PHI node don't have the LOOP in their parents, it can be folded
-          # TODO: ADD becomes a MUL, MAX can just become nothing
-          # NOTE: ADD -> MUL does not fold, this maintains original MULACC code path
-          if all(x.uop is not UOps.LOOP for x in get_recursive_parents(UOp(u.uop, u.dtype, u.vin[0:2], u.arg))) \
-            and u.vin[1].arg is BinaryOps.ADD and u.vin[1].vin[0].arg is not BinaryOps.MUL:
-            if DEBUG >= 4: print(f"removing PHI node {u}")
-            del self.saved_exprs[(u.uop, u.dtype, u.vin, u.arg)]
-            # NOTE: assuming u.vin[2].vin[1] and u.vin[2].vin[0] have the same dtype
-            loop_len = self.uop(UOps.ALU, u.vin[2].vin[1].dtype, (u.vin[2].vin[1], u.vin[2].vin[0]), BinaryOps.SUB, insert_before=self.uops.index(u))
-            if loop_len.dtype != u.dtype: loop_len = self.uop(UOps.CAST, u.dtype, (loop_len,), insert_before=self.uops.index(u))
-            new = self.uop(UOps.ALU, u.dtype, (u.vin[1], loop_len,), BinaryOps.MUL, insert_before=self.uops.index(u))
-            # replace u with new
-            for v in self.uops: v.vin = tuple(new if x is u else x for x in v.vin)
-            self.uops.remove(u)
-            changed_something = True
-
-    # (recursively) remove childless uops
-    self.uops = remove_childless_uops(self.uops)
-
-    # add UOps.END*
-    for u in self.uops:
-      if u.uop is UOps.LOOP:
-        # add END of loops after the last thing that (recursively) depends on them
-        self.uop(UOps.ENDLOOP, None, (u,), cachable=False, insert_before=self.uops.index(sorted(list(get_recursive_children(self.uops, u)), key=self.uops.index)[-1])+1)  # noqa: E501
-      elif u.uop is UOps.IF:
-        # END any if statements at the end of the uops
-        self.uop(UOps.ENDIF, None, (u,), cachable=False)
-
-    # verify the uop types
-    uops_type_verify(self.uops)
-
-  def uop(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None, cachable=True, insert_before=None, simplify=True) -> UOp:  # noqa: E501
-    if simplify:
-      if uop is UOps.PHI and len(vin) == 2: return vin[1]  # a phi without loops is a noop
-      if uop is UOps.GEP and vin[0].uop is UOps.CONST: return self.const(vin[0].arg, dtype, insert_before)
-      if uop is UOps.CAST and all(x.uop is UOps.CONST for x in vin) and all_same([x.arg for x in vin]):
-        return self.const(vin[0].arg, dtype, insert_before)
-      if uop is UOps.ALU:
-        # rewrites. NOTE: the rewritten NEG op is still around...
-        if arg is BinaryOps.ADD and vin[1].uop is UOps.ALU and vin[1].arg is UnaryOps.NEG:
-          return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable, insert_before)
-        # constant folding
-        if arg is UnaryOps.NEG and vin[0].uop is UOps.CONST: return self.const(-vin[0].arg, dtype, insert_before)
-        if arg is TernaryOps.WHERE and vin[1] == vin[2]: return vin[1]  # a conditional with the same results either way is a noop
-        if arg is BinaryOps.MUL and vin[0].uop is UOps.CONST and vin[1].uop is UOps.CONST and dtype is not None and dtypes.is_float(dtype):
-          return self.const(vin[0].arg * vin[1].arg, dtype, insert_before)
-        # zero folding
-        for x in [0,1]:
-          if arg is BinaryOps.ADD and vin[x].uop is UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
-          if arg is BinaryOps.MUL and vin[x].uop is UOps.CONST and vin[x].arg == 1.0: return vin[1-x]
-          if arg is BinaryOps.MUL and vin[x].uop is UOps.CONST and vin[x].arg == 0.0: return vin[x]
-        if arg is BinaryOps.SUB and vin[1].uop is UOps.CONST and vin[1].arg == 0.0: return vin[0]
-        if arg is BinaryOps.DIV and vin[1].uop is UOps.CONST and vin[1].arg == 1.0: return vin[0]
-
-    key = (uop, dtype, vin, arg)
-    if insert_before is None: insert_before = len(self.uops)
-    # check if the cached expr is valid with the given insert place.
-    if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and self.uops.index(expr) <= insert_before: return expr
-    ret = UOp(uop, dtype, vin, arg)
-    self.uops.insert(insert_before, ret)
-    if cachable: self.saved_exprs[key] = ret
-    return ret
-
   def ast_parse(self, x:LazyOp, acc: List[UOp], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], do_reduce=False, loop_ctx=tuple(), cache=None) -> List[UOp]:  # noqa: E501
     if cache is None: cache = {}
     if x in cache: return cache[x]
     if x.op in BufferOps: return loaded_buffers[x.arg]
     if x.op == UnaryOps.CAST:
-      return [self.uop(UOps.CAST, self.get_base_dtype(x.arg[0]), (u,), x.arg) for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers)]
+      return [self.uops.add(UOps.CAST, self.get_base_dtype(x.arg[0]), (u,), x.arg) for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers)]
     if x.op in ReduceOps and not do_reduce:
       assert offs is None, "not available if we aren't doing reduce"
       return acc
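The PHI folding that moves here unchanged is easiest to see with numbers: when the accumulated value is loop-invariant, summing it over a loop equals one multiply by the trip count (the loop_len it builds is end minus start). A worked example of the identity the rewrite relies on:

    c, start, end = 3.0, 0, 4
    acc = 0.0
    for _ in range(start, end): acc += c  # the PHI/ADD form
    assert acc == c * (end - start)       # the folded MUL form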
@@ -502,12 +416,12 @@ class Linearizer(Kernel):
       ret: List[UOp] = []
       input_acc = acc[:]
       for val, off in zip(zip(*values), cast(List[int], offs)):
-        acc[off] = self.uop(UOps.ALU, acc[off].dtype, vin=val+(acc[off],), arg=ops[cast(ReduceOps, x.op)])
+        acc[off] = self.uops.add(UOps.ALU, acc[off].dtype, vin=val+(acc[off],), arg=ops[cast(ReduceOps, x.op)])
         ret.append(acc[off])
       for off in range(len(acc)):
         if input_acc[off] != acc[off]:
-          acc[off] = self.uop(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]) + tuple(loop_ctx))
+          acc[off] = self.uops.add(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]) + tuple(loop_ctx))
     else:
-      ret = [self.uop(UOps.ALU, dtypes.bool if x.op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else val[-1].dtype, val, x.op) for val in zip(*values)]
+      ret = [self.uops.add(UOps.ALU, dtypes.bool if x.op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else val[-1].dtype, val, x.op) for val in zip(*values)]
     cache[x] = ret
     return ret
@@ -1,9 +1,11 @@
 from __future__ import annotations
-from typing import List, Set, Optional, Tuple, Any
-from tinygrad.helpers import DEBUG, flatten
 import functools
+from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict
+from collections import defaultdict
+from tinygrad.helpers import DEBUG, flatten, all_same
 from tinygrad.dtype import dtypes, DType
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
-from tinygrad.shape.symbolic import sint
+from tinygrad.shape.symbolic import sint, Variable
 from enum import Enum, auto
 from dataclasses import dataclass
@@ -23,96 +25,201 @@ class UOp:
   def __repr__(self):
     return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"

-def get_recursive_children(uops:List[UOp], x:UOp) -> Set[UOp]:
-  deps = set([x])
-  ssize = 0
-  while ssize != len(deps):
-    ssize = len(deps)
-    for u in uops:
-      if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])):
-        deps.add(u)
-  return deps
-
-UOPS_W_SIDE_EFFECTS = {UOps.STORE}
-def remove_childless_uops(uops:List[UOp]) -> List[UOp]:
-  while 1:
-    has_child: Set[UOp] = set()
-    for ru in uops:
-      for vu in ru.vin:
-        has_child.add(vu)
-    nu: List[UOp] = [x for x in uops if x in has_child or x.uop in UOPS_W_SIDE_EFFECTS]
-    if len(nu) == len(uops): break
-    if DEBUG >= 4: print(f"reduced UOp count from {len(uops)} to {len(nu)}")
-    uops = nu
-  del nu
-  return uops
-
-def fix_loop_scope(get_recursive_parents, uops:List[UOp]) -> List[UOp]:
-  loop_stack: List[List[UOp]] = [[]]
-  # push uops upward out of loop if it does not depend on the loop
-  for u in uops:
-    if not loop_stack[-1]: loop_stack[-1].append(u)
-    elif u.uop == UOps.LOOP: loop_stack.append([u])
-    elif u.uop not in [UOps.CONST, UOps.ALU, UOps.CAST, UOps.LOAD]: loop_stack[-1].append(u)
-    else:
-      parents = get_recursive_parents(u, with_phi=True)
-      # don't push any local buffer because there might have STORE and BARRIER (not considered as parent) between DEFINE_LOCAL and here
-      if any(u.uop == UOps.DEFINE_LOCAL for u in parents): loop_stack[-1].append(u)
-      else:
-        for i in reversed(range(len(loop_stack))):
-          # check backwards and put the uop in the first encounter with some dependency
-          if any(x in parents for x in loop_stack[i]) or i == 0:
-            loop_stack[i].append(u)
-            break
-  return flatten(loop_stack)
-
-# optional
-def uops_type_verify(uops:List[UOp]):
-  for u in uops:
-    uop, arg, vin, dtype = u.uop, u.arg, u.vin, u.dtype
-    if uop == UOps.ALU:
-      if arg in UnaryOps:
-        assert dtype == vin[0].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=}"
-      elif arg in (BinaryOps.CMPLT, BinaryOps.CMPEQ):
-        assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
-        assert vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
-      elif arg in BinaryOps:
-        assert dtype == vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
-      elif arg == TernaryOps.WHERE:
-        assert vin[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {vin[0].dtype=} != {dtypes.bool}"
-        assert dtype == vin[1].dtype == vin[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {vin[1].dtype=} != {vin[2].dtype=}"
-
-def uops_alu_resolve(u:UOp) -> sint:
+def uop_alu_resolve(u:UOp) -> sint:
   if u.uop == UOps.CONST: return u.arg
   elif u.uop == UOps.DEFINE_VAR: return u.arg
   elif u.uop == UOps.ALU and u.arg == BinaryOps.MUL:
-    return uops_alu_resolve(u.vin[0]) * uops_alu_resolve(u.vin[1])
+    return uop_alu_resolve(u.vin[0]) * uop_alu_resolve(u.vin[1])
   elif u.uop == UOps.ALU and u.arg == BinaryOps.ADD:
-    return uops_alu_resolve(u.vin[0]) + uops_alu_resolve(u.vin[1])
+    return uop_alu_resolve(u.vin[0]) + uop_alu_resolve(u.vin[1])
   else:
     raise RuntimeError(f"ALU resolve fail @ {u.uop}")
-def uops_flops_mem(uops:List[UOp]) -> Tuple[sint, sint]:
-  flops: sint = 0
-  mem: sint = 0
-  mults: sint = 1
-  mult_stack = []
-  for u in uops:
-    if u.uop is UOps.LOOP:
-      mult_stack.append(mults)
-      mults *= uops_alu_resolve(u.vin[1])
-    if u.uop is UOps.ENDLOOP:
-      mults = mult_stack.pop(-1)
-    if u.uop is UOps.ALU:
-      flops += mults
-    if u.uop is UOps.LOAD:
-      assert u.dtype is not None
-      mem += u.dtype.itemsize * mults
-    if u.uop is UOps.STORE:
-      assert u.vin[2].dtype is not None
-      mem += u.vin[2].dtype.itemsize * mults
-    if u.uop is UOps.WMMA:
-      if u.arg.startswith("__metal_wmma"): flops += 2*(8*8*8)//32 * mults
-      elif u.arg == "__hip_wmma_f16_f16" or u.arg == "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32": flops += 2*(16*16*16)//32 * mults
-      else: raise Exception("not implemented")
-  return flops, mem
+class UOpGraph:
+  def __init__(self):
+    # list of uops
+    self.uops: List[UOp] = []
+    # global uop cache
+    self.saved_exprs: Dict[Tuple, UOp] = dict()
+
+  def __iter__(self): return iter(self.uops)
+
+  def vars(self) -> List[Variable]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_VAR]
+
+  def graph(self):
+    from tinygrad.features.graph import graph_uops
+    graph_uops(self.uops)
+
+  def print(self):
+    for u in self.uops:
+      print(f"{self.uops.index(u):4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} {str([self.uops.index(x) for x in u.vin]):32s} {u.arg}")  # noqa: E501
+
+  def add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None, cachable=True, insert_before=None, simplify=True) -> UOp:  # noqa: E501
+    if simplify:
+      if uop is UOps.PHI and len(vin) == 2: return vin[1]  # a phi without loops is a noop
+      if uop is UOps.GEP and vin[0].uop is UOps.CONST: return self.add(UOps.CONST, dtype, arg=vin[0].arg, insert_before=insert_before)
+      if uop is UOps.CAST and all(x.uop is UOps.CONST for x in vin) and all_same([x.arg for x in vin]):
+        return self.add(UOps.CONST, dtype, arg=vin[0].arg, insert_before=insert_before)
+      if uop is UOps.ALU:
+        # rewrites. NOTE: the rewritten NEG op is still around...
+        if arg is BinaryOps.ADD and vin[1].uop is UOps.ALU and vin[1].arg is UnaryOps.NEG:
+          return self.add(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable, insert_before)
+        # constant folding
+        if arg is UnaryOps.NEG and vin[0].uop is UOps.CONST: return self.add(UOps.CONST, dtype, arg=-vin[0].arg, insert_before=insert_before)
+        if arg is TernaryOps.WHERE and vin[1] == vin[2]: return vin[1]  # a conditional with the same results either way is a noop
+        if arg is BinaryOps.MUL and vin[0].uop is UOps.CONST and vin[1].uop is UOps.CONST and dtype is not None and dtypes.is_float(dtype):
+          return self.add(UOps.CONST, dtype, arg=vin[0].arg * vin[1].arg, insert_before=insert_before)
+        # zero folding
+        for x in [0,1]:
+          if arg is BinaryOps.ADD and vin[x].uop is UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
+          if arg is BinaryOps.MUL and vin[x].uop is UOps.CONST and vin[x].arg == 1.0: return vin[1-x]
+          if arg is BinaryOps.MUL and vin[x].uop is UOps.CONST and vin[x].arg == 0.0: return vin[x]
+        if arg is BinaryOps.SUB and vin[1].uop is UOps.CONST and vin[1].arg == 0.0: return vin[0]
+        if arg is BinaryOps.DIV and vin[1].uop is UOps.CONST and vin[1].arg == 1.0: return vin[0]
+
+    key = (uop, dtype, vin, arg)
+    if insert_before is None: insert_before = len(self.uops)
+    # check if the cached expr is valid with the given insert place.
+    if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and self.uops.index(expr) <= insert_before: return expr
+    ret = UOp(uop, dtype, vin, arg)
+    self.uops.insert(insert_before, ret)
+    if cachable: self.saved_exprs[key] = ret
+    return ret
+
+  def remove_childless(self):
+    while 1:
+      has_child: Set[UOp] = set()
+      for ru in self.uops:
+        for vu in ru.vin:
+          has_child.add(vu)
+      nu: List[UOp] = [x for x in self.uops if x in has_child or x.uop is UOps.STORE]
+      if len(nu) == len(self.uops): break
+      if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
+      self.uops = nu
+
+  # optional
+  def type_verify(self):
+    for u in self.uops:
+      uop, arg, vin, dtype = u.uop, u.arg, u.vin, u.dtype
+      if uop == UOps.ALU:
+        if arg in UnaryOps:
+          assert dtype == vin[0].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=}"
+        elif arg in (BinaryOps.CMPLT, BinaryOps.CMPEQ):
+          assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
+          assert vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
+        elif arg in BinaryOps:
+          assert dtype == vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
+        elif arg == TernaryOps.WHERE:
+          assert vin[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {vin[0].dtype=} != {dtypes.bool}"
+          assert dtype == vin[1].dtype == vin[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {vin[1].dtype=} != {vin[2].dtype=}"
+
+  def get_recursive_children(self, x:UOp) -> Set[UOp]:
+    deps = set([x])
+    ssize = 0
+    while ssize != len(deps):
+      ssize = len(deps)
+      for u in self.uops:
+        if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])):
+          deps.add(u)
+    return deps
+
+  def add_ends(self):
+    for u in self.uops:
+      if u.uop is UOps.LOOP:
+        # add END of loops after the last thing that (recursively) depends on them
+        self.add(UOps.ENDLOOP, None, (u,), cachable=False, insert_before=self.uops.index(sorted(list(self.get_recursive_children(u)), key=self.uops.index)[-1])+1)  # noqa: E501
+      elif u.uop is UOps.IF:
+        # END any if statements at the end of the uops
+        self.add(UOps.ENDIF, None, (u,), cachable=False)
+
+  def fix_loop_scope(self, get_recursive_parents):
+    loop_stack: List[List[UOp]] = [[]]
+    # push uops upward out of loop if it does not depend on the loop
+    for u in self.uops:
+      if not loop_stack[-1]: loop_stack[-1].append(u)
+      elif u.uop == UOps.LOOP: loop_stack.append([u])
+      elif u.uop not in [UOps.CONST, UOps.ALU, UOps.CAST, UOps.LOAD]: loop_stack[-1].append(u)
+      else:
+        parents = get_recursive_parents(u, with_phi=True)
+        # don't push any local buffer because there might have STORE and BARRIER (not considered as parent) between DEFINE_LOCAL and here
+        if any(u.uop == UOps.DEFINE_LOCAL for u in parents): loop_stack[-1].append(u)
+        else:
+          for i in reversed(range(len(loop_stack))):
+            # check backwards and put the uop in the first encounter with some dependency
+            if any(x in parents for x in loop_stack[i]) or i == 0:
+              loop_stack[i].append(u)
+              break
+    self.uops = flatten(loop_stack)
+
+  def uoptimize(self):
+    # get PHI node loop scope, link anything using a DEFINE_ACC to the loop as a "parent"
+    acc_scope: DefaultDict[UOp, List[UOp]] = defaultdict(list)
+    for u in self.uops:
+      if u.uop == UOps.PHI:
+        acc_scope[u.vin[0]] += u.vin[2:]
+
+    # graph helper functions
+    @functools.lru_cache(None)
+    def get_recursive_parents(x:UOp, with_phi=False) -> Set[UOp]:
+      return set.union(set(x.vin), *[get_recursive_parents(p, with_phi) for p in x.vin], set(acc_scope[x]) if with_phi else set())
+
+    # fix loop scope, push uops upward out of loop if it does not depend on the loop
+    self.fix_loop_scope(get_recursive_parents)
+
+    # uops optimization
+    changed_something = True
+    while changed_something:
+      changed_something = False
+      for u in self.uops:
+        if u.uop is UOps.PHI and len(u.vin) == 3:
+          # if the parents of the PHI node don't have the LOOP in their parents, it can be folded
+          # TODO: ADD becomes a MUL, MAX can just become nothing
+          # NOTE: ADD -> MUL does not fold, this maintains original MULACC code path
+          if all(x.uop is not UOps.LOOP for x in get_recursive_parents(UOp(u.uop, u.dtype, u.vin[0:2], u.arg))) \
+            and u.vin[1].arg is BinaryOps.ADD and u.vin[1].vin[0].arg is not BinaryOps.MUL:
+            if DEBUG >= 4: print(f"removing PHI node {u}")
+            del self.saved_exprs[(u.uop, u.dtype, u.vin, u.arg)]
+            # NOTE: assuming u.vin[2].vin[1] and u.vin[2].vin[0] have the same dtype
+            loop_len = self.add(UOps.ALU, u.vin[2].vin[1].dtype, (u.vin[2].vin[1], u.vin[2].vin[0]), BinaryOps.SUB, insert_before=self.uops.index(u))
+            if loop_len.dtype != u.dtype: loop_len = self.add(UOps.CAST, u.dtype, (loop_len,), insert_before=self.uops.index(u))
+            new = self.add(UOps.ALU, u.dtype, (u.vin[1], loop_len,), BinaryOps.MUL, insert_before=self.uops.index(u))
+            # replace u with new
+            for v in self.uops: v.vin = tuple(new if x is u else x for x in v.vin)
+            self.uops.remove(u)
+            changed_something = True
+
+    # (recursively) remove childless uops
+    self.remove_childless()
+
+    # add UOps.END*
+    self.add_ends()
+
+    # verify the uop types
+    self.type_verify()
+
+  def flops_mem(self) -> Tuple[sint, sint]:
+    flops: sint = 0
+    mem: sint = 0
+    mults: sint = 1
+    mult_stack = []
+    for u in self.uops:
+      if u.uop is UOps.LOOP:
+        mult_stack.append(mults)
+        mults *= uop_alu_resolve(u.vin[1])
+      if u.uop is UOps.ENDLOOP:
+        mults = mult_stack.pop(-1)
+      if u.uop is UOps.ALU:
+        flops += mults
+      if u.uop is UOps.LOAD:
+        assert u.dtype is not None
+        mem += u.dtype.itemsize * mults
+      if u.uop is UOps.STORE:
+        assert u.vin[2].dtype is not None
+        mem += u.vin[2].dtype.itemsize * mults
+      if u.uop is UOps.WMMA:
+        if u.arg.startswith("__metal_wmma"): flops += 2*(8*8*8)//32 * mults
+        elif u.arg == "__hip_wmma_f16_f16" or u.arg == "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32": flops += 2*(16*16*16)//32 * mults
+        else: raise Exception("not implemented")
+    return flops, mem

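UOpGraph.add is now the single place where peephole simplification happens, so every producer gets constant and identity folding for free. A quick demonstration of two of the rules above, assuming this diff is applied:

    from tinygrad.codegen.uops import UOpGraph, UOps
    from tinygrad.dtype import dtypes
    from tinygrad.ops import BinaryOps

    g = UOpGraph()
    c4 = g.add(UOps.CONST, dtypes.float, arg=4.0)
    c1 = g.add(UOps.CONST, dtypes.float, arg=1.0)
    assert g.add(UOps.ALU, dtypes.float, (c4, c1), BinaryOps.MUL) is c4  # folds to the cached constant 4.0
    c2 = g.add(UOps.CONST, dtypes.float, arg=2.0)
    out = g.add(UOps.ALU, dtypes.float, (c4, c2), BinaryOps.MUL)
    assert out.uop is UOps.CONST and out.arg == 8.0                      # const * const folds at build time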
@@ -3,7 +3,6 @@ from collections import defaultdict
 from typing import TYPE_CHECKING, Any, List, Optional, Dict, Tuple, ClassVar
 import importlib, inspect, functools, pathlib, time, ctypes
 from tinygrad.dtype import DType, ImageDType
-from tinygrad.codegen.uops import UOps
 from tinygrad.helpers import ansilen, DEBUG, getenv, colored, BEAM, NOOPT, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put
 from tinygrad.helpers import prod
 from tinygrad.shape.symbolic import Variable, sym_infer, sint
@@ -236,13 +235,11 @@ class Compiled:
     assert self.compiler is not None, "compiler is required to run AST"
     k.linearize()
     info = get_lazyop_info(k.ast)
-    from tinygrad.codegen.uops import uops_flops_mem
-    ops, mem = uops_flops_mem(k.uops)
+    ops, mem = k.uops.flops_mem()
     run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
     # NOTE: we use min here to ignore the indexing FLOPS
-    ret = CompiledASTRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops), self, k.global_size, k.local_size,
-                            [x.arg for x in k.uops if x.uop is UOps.DEFINE_VAR],
-                            min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
+    ret = CompiledASTRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops.uops), self, k.global_size, k.local_size,
+                            k.uops.vars(), min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
     return ret

   def get_linearizer(self, ast:LazyOp) -> Linearizer:

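flops_mem() walks the linear uop list once, scaling each ALU/LOAD/STORE by the product of the enclosing LOOP trip counts; the caller then multiplies by the launch grid and clamps with the AST-level estimate. The shape of that computation, names as in the hunk above:

    ops, mem = k.uops.flops_mem()
    run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
    flops_estimate = min(info.flops, ops * run_count)  # min ignores the indexing FLOPS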
@@ -2,7 +2,6 @@ from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
 import itertools, functools, random, math, time, multiprocessing, traceback, signal
 from tinygrad.device import Device, Compiled, Buffer, CompiledASTRunner, Compiler
 from tinygrad.ops import MemBuffer
-from tinygrad.codegen.uops import UOps
 from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
 from tinygrad.dtype import ImageDType
 from tinygrad.codegen.linearizer import Linearizer
@@ -49,8 +48,8 @@ def _time_program(variables:List[Variable], rdev:Compiled, lib:bytes, global_siz
 def _compile_linearizer(compiler:Compiler, lin:Linearizer, name:Optional[str]=None) -> Tuple[bytes, Optional[List[int]], Optional[List[int]],
                                                                                              List[Variable]]:
   lin.linearize()
-  src = compiler.render(name if name is not None else to_function_name(lin.name), lin.uops)  # NOTE: these all have the same name for deduping
-  return compiler.compile(src), lin.global_size, lin.local_size, [x.arg for x in lin.uops if x.uop is UOps.DEFINE_VAR]
+  src = compiler.render(name if name is not None else to_function_name(lin.name), lin.uops.uops)  # NOTE: these all have the same name for deduping
+  return compiler.compile(src), lin.global_size, lin.local_size, lin.uops.vars()

 def _try_compile_linearized_w_idx(x, compiler:Compiler):
   try: return (x[0], _compile_linearizer(compiler, x[1], "test"))
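The same cleanup lands here: the hand-rolled scan for DEFINE_VAR args is replaced by the new accessor, which is equivalent by definition:

    # UOpGraph.vars() is exactly this comprehension over the graph's uops
    assert lin.uops.vars() == [x.arg for x in lin.uops if x.uop is UOps.DEFINE_VAR]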