rangeify: don't tag consts, they are global (#12247)

* rangeify: don't tag consts, they are global

* don't map movement ops

* sym failing test

* remove that

* update comment

* simpler test

* work
This commit is contained in:
qazal
2025-09-19 15:25:03 +03:00
committed by GitHub
parent cc038b31b6
commit bb59eed82f
2 changed files with 15 additions and 3 deletions

View File

@@ -425,6 +425,11 @@ class TestSchedule(unittest.TestCase):
b = Tensor.full((4, 4), 1.).contiguous().realize()
check_schedule([a+b, a+b], 1)
def test_const_realize(self):
t = Tensor.ones(2)
check_schedule(t[0], 0)
check_schedule(t[1], 0)
def test_fold_double_unary(self):
y = Tensor.empty(2)
out = y.sum(keepdim=True).sqrt().neg()

View File

@@ -306,6 +306,10 @@ def might_end_axis(idx:UOp):
def unprocessed_index(x:UOp): raise RuntimeError(f"unprocessed index on {x.src[0].op}")
def unprocessed_mop(x:UOp):
assert x.src[0].op in GroupOp.Movement.union({*ALWAYS_CONTIGUOUS, Ops.REALIZE, Ops.BUFFERIZE}), f"unprocessed movement op on {x.src[0]}"
return x.replace(tag=None)
pm_rangeify = pm_mops+PatternMatcher([
# sink contigs to kick it off
(UPat(Ops.REALIZE, src=(UPat(),), name="x", allow_any_len=True), map_realize),
@@ -319,8 +323,8 @@ pm_rangeify = pm_mops+PatternMatcher([
# if we come across this, remove it. it was a CHILD unused in an INDEX
(UPat(Ops.CHILD, src=(UPat(Ops.CHILDREN, src=(UPat.var("x"),)),)), lambda x: x),
# CONST (or DEFINE_VAR) can't have axes. remove srcs when we INDEX it
(UPat(Ops.INDEX, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),)), lambda c: c.replace(tag=None)),
# CONST (or DEFINE_VAR) can't have axes. remove INDEX when we get here
(UPat(Ops.INDEX, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),)), lambda c: c),
# handle arg on any op with weight. old endrange stuff
(UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.REDUCE_AXIS})),), allow_any_len=True, name="idx"), might_end_axis),
@@ -340,6 +344,9 @@ pm_rangeify = pm_mops+PatternMatcher([
# assert if there's any index we didn't process
(UPat(GroupOp.All-{Ops.REALIZE, Ops.BUFFERIZE}).f(Ops.INDEX, name="x"), unprocessed_index),
# if any movement ops make it here they didn't get INDEX, remove tags
(UPat(GroupOp.Movement, name="x"), unprocessed_mop),
])
# *****************
@@ -556,7 +563,7 @@ def tag_uop(ctx:list[UOp], x:UOp):
return x.replace(tag=(len(ctx)-1,))
add_tags = PatternMatcher([
# don't tag BUFFERs, they are global
(UPat(GroupOp.All-{Ops.BUFFER, Ops.DEVICE, Ops.UNIQUE, Ops.DEFINE_VAR, Ops.BIND}, name="x"), tag_uop),
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.DEFINE_VAR, Ops.BIND}, name="x"), tag_uop),
])
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len([u for u in UOp.sink(*ret.values()).toposort() if u.op is Ops.KERNEL]))}", True)