mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
rangeify: don't tag consts, they are global (#12247)
* rangeify: don't tag consts, they are global * don't map movement ops * sym failing test * remove that * update comment * simpler test * work
This commit is contained in:
@@ -425,6 +425,11 @@ class TestSchedule(unittest.TestCase):
|
||||
b = Tensor.full((4, 4), 1.).contiguous().realize()
|
||||
check_schedule([a+b, a+b], 1)
|
||||
|
||||
def test_const_realize(self):
|
||||
t = Tensor.ones(2)
|
||||
check_schedule(t[0], 0)
|
||||
check_schedule(t[1], 0)
|
||||
|
||||
def test_fold_double_unary(self):
|
||||
y = Tensor.empty(2)
|
||||
out = y.sum(keepdim=True).sqrt().neg()
|
||||
|
||||
@@ -306,6 +306,10 @@ def might_end_axis(idx:UOp):
|
||||
|
||||
def unprocessed_index(x:UOp): raise RuntimeError(f"unprocessed index on {x.src[0].op}")
|
||||
|
||||
def unprocessed_mop(x:UOp):
|
||||
assert x.src[0].op in GroupOp.Movement.union({*ALWAYS_CONTIGUOUS, Ops.REALIZE, Ops.BUFFERIZE}), f"unprocessed movement op on {x.src[0]}"
|
||||
return x.replace(tag=None)
|
||||
|
||||
pm_rangeify = pm_mops+PatternMatcher([
|
||||
# sink contigs to kick it off
|
||||
(UPat(Ops.REALIZE, src=(UPat(),), name="x", allow_any_len=True), map_realize),
|
||||
@@ -319,8 +323,8 @@ pm_rangeify = pm_mops+PatternMatcher([
|
||||
# if we come across this, remove it. it was a CHILD unused in an INDEX
|
||||
(UPat(Ops.CHILD, src=(UPat(Ops.CHILDREN, src=(UPat.var("x"),)),)), lambda x: x),
|
||||
|
||||
# CONST (or DEFINE_VAR) can't have axes. remove srcs when we INDEX it
|
||||
(UPat(Ops.INDEX, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),)), lambda c: c.replace(tag=None)),
|
||||
# CONST (or DEFINE_VAR) can't have axes. remove INDEX when we get here
|
||||
(UPat(Ops.INDEX, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),)), lambda c: c),
|
||||
|
||||
# handle arg on any op with weight. old endrange stuff
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.REDUCE_AXIS})),), allow_any_len=True, name="idx"), might_end_axis),
|
||||
@@ -340,6 +344,9 @@ pm_rangeify = pm_mops+PatternMatcher([
|
||||
|
||||
# assert if there's any index we didn't process
|
||||
(UPat(GroupOp.All-{Ops.REALIZE, Ops.BUFFERIZE}).f(Ops.INDEX, name="x"), unprocessed_index),
|
||||
|
||||
# if any movement ops make it here they didn't get INDEX, remove tags
|
||||
(UPat(GroupOp.Movement, name="x"), unprocessed_mop),
|
||||
])
|
||||
|
||||
# *****************
|
||||
@@ -556,7 +563,7 @@ def tag_uop(ctx:list[UOp], x:UOp):
|
||||
return x.replace(tag=(len(ctx)-1,))
|
||||
add_tags = PatternMatcher([
|
||||
# don't tag BUFFERs, they are global
|
||||
(UPat(GroupOp.All-{Ops.BUFFER, Ops.DEVICE, Ops.UNIQUE, Ops.DEFINE_VAR, Ops.BIND}, name="x"), tag_uop),
|
||||
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.DEFINE_VAR, Ops.BIND}, name="x"), tag_uop),
|
||||
])
|
||||
|
||||
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len([u for u in UOp.sink(*ret.values()).toposort() if u.op is Ops.KERNEL]))}", True)
|
||||
|
||||
Reference in New Issue
Block a user