mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
start moving image things to rewrite rules (#8678)
* start moving image things to rewrite rules [pr]
* that too
* as expected
* fix
* Revert "fix"
This reverts commit fd03c9464b.
This commit is contained in:
@@ -46,6 +46,7 @@ class TestImageDType(unittest.TestCase):
|
||||
assert isinstance(it.lazydata.base.realized.dtype, ImageDType)
|
||||
np.testing.assert_equal(tst, it.numpy())
|
||||
|
||||
@unittest.expectedFailure # this isn't supported anymore, CAST to ImageDType stays ImageDType
|
||||
def test_image_cast_and_back_collapses(self):
|
||||
data = Tensor.randn(9*27*4).realize()
|
||||
tst = data.numpy()
|
||||
|
||||
@@ -105,7 +105,7 @@ def add_buffers(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
|
||||
dtype = buf.dtype.base
|
||||
# ASSIGN already has a target buffer, otherwise we create a new one
|
||||
buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
|
||||
op = buf.replace(dtype=dtype.base, src=tuple(add_buffers(x, ctx, cache) for x in buf.src))
|
||||
op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, ctx, cache) for x in buf.src))
|
||||
# track the underlying tensor uop for this buffer
|
||||
ctx.tensor_uops[buf_uop] = [buf]
|
||||
# (early) bufferize
|
||||
@@ -198,6 +198,8 @@ to_si = PatternMatcher([
|
||||
(UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
|
||||
# PRELOAD becomes LOAD
|
||||
(UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
|
||||
# once images are loaded they become the base dtype
|
||||
(UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
|
||||
])
|
||||
|
||||
# LOAD(BUFFER) -> the STORE value if it's we're doing the STORE in the same kernel
|
||||
@@ -482,7 +484,7 @@ def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
|
||||
|
||||
break_sched = PatternMatcher([
|
||||
# CONST is always fused and generated
|
||||
(UPat(Ops.CONST, name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: UOp.const(x.dtype.base, x.const_arg).valid(st.st)),
|
||||
(UPat(Ops.CONST, name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: UOp.const(x.dtype, x.const_arg).valid(st.st)),
|
||||
(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.var("val"))), unbind_variable),
|
||||
# VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
|
||||
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
|
||||
|
||||
Reference in New Issue
Block a user