diff --git a/test/test_schedule.py b/test/test_schedule.py index 0f71dae331..20ca1ccd2c 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -2366,5 +2366,56 @@ class TestContiguous(unittest.TestCase): b = a.expand((4, 4)).contiguous().contiguous() check_schedule(b, 1) + +class TestUOpBecome(unittest.TestCase): + # the simplest case, if we create a new BUFFER for this UOp + def test_new_buffer(self): + a = Tensor.empty(4, 4) + b = Tensor.empty(4, 4) + add = a+b + check_schedule(add, 1) + assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(add.lazydata.base, {}) + + def test_new_buffer_view(self): + a = Tensor.empty(4, 4) + b = Tensor.empty(4, 4) + add = (a+b).reshape(8, 2) + check_schedule(add, 1) + assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(add.lazydata.base, {}) + # VIEW is preserved after the becomes rewrite. + self.assertEqual(add.lazydata.shape, (8, 2)) + assert add.lazydata is not add.lazydata.base + + def test_become_existing_buffer(self): + a = Tensor.empty(4, 4) + b = a*1 + assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul + check_schedule(b, 0) + assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(b.lazydata.base, {}) # scheduling replaces the tensor lazydata with a VIEW(BUFFER) + self.assertIs(a.lazydata.base.buffer, b.lazydata.base.buffer) + + def test_become_const_in_base(self): + a = Tensor.empty(4) + b = a*0 + assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul + check_schedule(b, 0) + assert UPat(Ops.CONST, arg=0).match(b.lazydata.base, {}) # scheduling replaces the tensor lazydata base with a CONST + + def test_become_const_in_view(self): + # if we shrink the base down to a size 0, only the VIEW becomes CONST, base is unchanged. + add = Tensor.empty(2, 2)+Tensor.empty(2, 2) + b = add.shrink(((0, 1), (0, 0))) + check_schedule(b, 0) + assert UPat(Ops.CONST, arg=0).match(b.lazydata, {}) + self.assertEqual(b.shape, (1, 0)) + # the base is untouched. 
+ assert UPat(Ops.ADD).match(add.lazydata, {}) + + def test_become_const_from_const(self): + const_add = Tensor(1)+Tensor(2) + assert UPat(Ops.ADD).match(const_add.lazydata, {}) + check_schedule(const_add, 0) + assert UPat(Ops.CONST, arg=3).match(const_add.lazydata.base, {}) + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/test/unit/test_tensor_uop_representation.py b/test/unit/test_tensor_uop_representation.py index dc8d0b64aa..6ea1b011a4 100644 --- a/test/unit/test_tensor_uop_representation.py +++ b/test/unit/test_tensor_uop_representation.py @@ -71,9 +71,9 @@ class TestTensorUopRepresentation(unittest.TestCase): def test_viewed_consts_do_not_realize(self): a = Tensor.ones(10, 10) print(a.lazydata) - pre_realize = a.lazydata a.realize() - assert a.lazydata is pre_realize + is_pattern(a, const_pattern) + self.assertEqual(a.lazydata.shape, (10, 10)) # currently, CONSTs have a "fake" BUFFER. this should be fixed # current: diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 1c5931962e..71cc12492c 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -514,10 +514,14 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu for buf_uop in store_uops: for luop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[luop] = buf_uop.view(unwrap(luop.st)) - # tensors can become an existing buffer, no ScheduleItem needed + # tensors can become an existing buffer or simplify to a const, no ScheduleItem needed for k,v in tensor_map.items(): - # NOTE: we only add base tensors to becomes_map - if k is not v and v.is_realized and k is k.base: ctx.becomes_map[k] = v.view(unwrap(k.st)) + # NOOP: k and v share the same base, nothing to become + if k.base is v.base: continue + # NOTE: only the base tensors get a BUFFER UOp + if v.is_realized and k is k.base: ctx.becomes_map[k] = v.view(unwrap(k.st)) + # otherwise if it simplified to a CONST the UOp just becomes that CONST + elif v.op is Ops.CONST: ctx.becomes_map[k] = v # add kernel children 
schedule_targets = {out:si for si in prescheduled for out in si.outputs}