late load

This commit is contained in:
George Hotz
2025-10-30 09:20:51 +08:00
parent 2ef53a7a90
commit cd8272f129
3 changed files with 11 additions and 10 deletions

View File

@@ -43,7 +43,7 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
if not drop_stmt and idx is start_idx: return None
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in valid.split_uop(Ops.AND) if s not in drop_stmt]) else None
return buf.index(idx.valid(new_valid) if new_valid is not None else idx)
return buf.index(idx.valid(new_valid) if new_valid is not None else idx, ptr=True)
load_store_indexing = PatternMatcher([
@@ -52,7 +52,7 @@ load_store_indexing = PatternMatcher([
# simplify away long after index has been lowered
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x", dtypes.long), UPat.var("c", dtypes.bool))), lambda buf,x,c: simplify_valid_load(buf, x, c)),
# drop true gate
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.const(dtypes.bool, True)),), lambda buf,x: buf.index(x)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.const(dtypes.bool, True)),), lambda buf,x: buf.index(x, ptr=True)),
])
# ***** load/store grouping *****
@@ -302,11 +302,11 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,))
acc_init = acc.after(*input_ranges).index(UOp.const(dtypes.int, 0)).store(identity) if len(input_ranges) else \
acc.index(UOp.const(dtypes.int, 0)).store(identity)
lst = [acc.after(acc_init, *reduce_range).index(UOp.const(dtypes.int, 0)).load()] + lst # put acc as the first element
lst = [acc.after(acc_init, *reduce_range).index(UOp.const(dtypes.int, 0))] + lst # put acc as the first element
ctx.acc_num += 1
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
if len(reduce_range) == 0: return ret
return acc.after(acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range)).index(UOp.const(dtypes.int, 0)).load()
return acc.after(acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range)).index(UOp.const(dtypes.int, 0))
pm_reduce = PatternMatcher([
# REDUCE -> DEFINE_ACC+ASSIGN

View File

@@ -133,7 +133,7 @@ def no_load(u:UOp) -> bool: return not any(x.op is Ops.LOAD for x in u.backward_
pm_load_collapse = PatternMatcher([
(UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_load_collapse),
# we want to make sure we don't do math on a loaded index, since that can cause overflow; this undoes the rule in pm_reduce_load_collapse
#((UPat.var("x", dtypes.index)+UPat.var("y"))<UPat.var("c"), lambda x,y,c: x < c-y if no_load(y) and no_load(c) and not no_load(x) else None),
((UPat.var("x", dtypes.index)+UPat.var("y"))<UPat.var("c"), lambda x,y,c: x < c-y if no_load(y) and no_load(c) and not no_load(x) else None),
])
def cut_store_range(ctx, store:UOp, r:UOp):

View File

@@ -339,8 +339,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
def index(self, *srcs:UOp|None, **kwargs):
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def index(self, *srcs:UOp|None, ptr=False, **kwargs):
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def __getitem__(self, *idx): return self.index(*idx)
def const_like(self, b:ConstLike):
# constants can optionally have a DEVICE source
@@ -1199,10 +1199,11 @@ pm_lower_index_dtype = PatternMatcher([
(UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.index), UPat.cvar("val").cast(dtypes.index))), lambda var,val: var.bind(val).cast(dtypes.index)),
(UPat(Ops.CAST, src=(UPat(name="x").cast(dtypes.index),), name="c"), lambda x,c: x.cast(c.dtype)),
# lower Invalid
(UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond)),
(UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond, ptr=True)),
# remove hanging casts
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))),
lambda buf,idx,valid: buf.index(idx, valid, ptr=True)),
(UPat((Ops.STORE, Ops.LOAD), src=(UPat(), UPat(), UPat().cast(dtypes.index)), allow_any_len=True, name="s"),
lambda s: s.replace(src=s.src[:2]+tuple(u.src[0] for u in s.src[2:]))),
(UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"),