mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
late load
This commit is contained in:
@@ -43,7 +43,7 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
|
||||
|
||||
if not drop_stmt and idx is start_idx: return None
|
||||
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in valid.split_uop(Ops.AND) if s not in drop_stmt]) else None
|
||||
return buf.index(idx.valid(new_valid) if new_valid is not None else idx)
|
||||
return buf.index(idx.valid(new_valid) if new_valid is not None else idx, ptr=True)
|
||||
|
||||
|
||||
load_store_indexing = PatternMatcher([
|
||||
@@ -52,7 +52,7 @@ load_store_indexing = PatternMatcher([
|
||||
# simplify away long after index has been lowered
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x", dtypes.long), UPat.var("c", dtypes.bool))), lambda buf,x,c: simplify_valid_load(buf, x, c)),
|
||||
# drop true gate
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.const(dtypes.bool, True)),), lambda buf,x: buf.index(x)),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.const(dtypes.bool, True)),), lambda buf,x: buf.index(x, ptr=True)),
|
||||
])
|
||||
|
||||
# ***** load/store grouping *****
|
||||
@@ -302,11 +302,11 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
|
||||
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,))
|
||||
acc_init = acc.after(*input_ranges).index(UOp.const(dtypes.int, 0)).store(identity) if len(input_ranges) else \
|
||||
acc.index(UOp.const(dtypes.int, 0)).store(identity)
|
||||
lst = [acc.after(acc_init, *reduce_range).index(UOp.const(dtypes.int, 0)).load()] + lst # put acc as the first element
|
||||
lst = [acc.after(acc_init, *reduce_range).index(UOp.const(dtypes.int, 0))] + lst # put acc as the first element
|
||||
ctx.acc_num += 1
|
||||
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
|
||||
if len(reduce_range) == 0: return ret
|
||||
return acc.after(acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range)).index(UOp.const(dtypes.int, 0)).load()
|
||||
return acc.after(acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range)).index(UOp.const(dtypes.int, 0))
|
||||
|
||||
pm_reduce = PatternMatcher([
|
||||
# REDUCE -> DEFINE_ACC+ASSIGN
|
||||
|
||||
@@ -133,7 +133,7 @@ def no_load(u:UOp) -> bool: return not any(x.op is Ops.LOAD for x in u.backward_
|
||||
pm_load_collapse = PatternMatcher([
|
||||
(UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_load_collapse),
|
||||
# we want to make sure we dont do math on a loaded index since that can cause overflow, this undoes the rule in pm_reduce_load_collapse
|
||||
#((UPat.var("x", dtypes.index)+UPat.var("y"))<UPat.var("c"), lambda x,y,c: x < c-y if no_load(y) and no_load(c) and not no_load(x) else None),
|
||||
((UPat.var("x", dtypes.index)+UPat.var("y"))<UPat.var("c"), lambda x,y,c: x < c-y if no_load(y) and no_load(c) and not no_load(x) else None),
|
||||
])
|
||||
|
||||
def cut_store_range(ctx, store:UOp, r:UOp):
|
||||
|
||||
@@ -339,8 +339,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
|
||||
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
|
||||
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
|
||||
def index(self, *srcs:UOp|None, **kwargs):
|
||||
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
||||
def index(self, *srcs:UOp|None, ptr=False, **kwargs):
|
||||
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
||||
def __getitem__(self, *idx): return self.index(*idx)
|
||||
def const_like(self, b:ConstLike):
|
||||
# constants can optionally have a DEVICE source
|
||||
@@ -1199,10 +1199,11 @@ pm_lower_index_dtype = PatternMatcher([
|
||||
(UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.index), UPat.cvar("val").cast(dtypes.index))), lambda var,val: var.bind(val).cast(dtypes.index)),
|
||||
(UPat(Ops.CAST, src=(UPat(name="x").cast(dtypes.index),), name="c"), lambda x,c: x.cast(c.dtype)),
|
||||
# lower Invalid
|
||||
(UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond)),
|
||||
(UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond, ptr=True)),
|
||||
# remove hanging casts
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx)),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))),
|
||||
lambda buf,idx,valid: buf.index(idx, valid, ptr=True)),
|
||||
(UPat((Ops.STORE, Ops.LOAD), src=(UPat(), UPat(), UPat().cast(dtypes.index)), allow_any_len=True, name="s"),
|
||||
lambda s: s.replace(src=s.src[:2]+tuple(u.src[0] for u in s.src[2:]))),
|
||||
(UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"),
|
||||
|
||||
Reference in New Issue
Block a user