start moving image things to rewrite rules (#8678)

* start moving image things to rewrite rules [pr]

* that too

* as expected

* fix

* Revert "fix"

This reverts commit fd03c9464b.
This commit is contained in:
qazal
2025-01-20 06:34:29 -05:00
committed by GitHub
parent b1847d561f
commit 3499a2c72d
2 changed files with 5 additions and 2 deletions

View File

@@ -46,6 +46,7 @@ class TestImageDType(unittest.TestCase):
assert isinstance(it.lazydata.base.realized.dtype, ImageDType)
np.testing.assert_equal(tst, it.numpy())
@unittest.expectedFailure # this isn't supported anymore, CAST to ImageDType stays ImageDType
def test_image_cast_and_back_collapses(self):
data = Tensor.randn(9*27*4).realize()
tst = data.numpy()

View File

@@ -105,7 +105,7 @@ def add_buffers(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
dtype = buf.dtype.base
# ASSIGN already has a target buffer, otherwise we create a new one
buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
op = buf.replace(dtype=dtype.base, src=tuple(add_buffers(x, ctx, cache) for x in buf.src))
op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, ctx, cache) for x in buf.src))
# track the underlying tensor uop for this buffer
ctx.tensor_uops[buf_uop] = [buf]
# (early) bufferize
@@ -198,6 +198,8 @@ to_si = PatternMatcher([
(UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
# PRELOAD becomes LOAD
(UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
# once images are loaded they become the base dtype
(UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
])
# LOAD(BUFFER) -> the STORE value if it's we're doing the STORE in the same kernel
@@ -482,7 +484,7 @@ def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
break_sched = PatternMatcher([
# CONST is always fused and generated
(UPat(Ops.CONST, name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: UOp.const(x.dtype.base, x.const_arg).valid(st.st)),
(UPat(Ops.CONST, name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: UOp.const(x.dtype, x.const_arg).valid(st.st)),
(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.var("val"))), unbind_variable),
# VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),