diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py
index 17a725b3bf..6d00991f8d 100644
--- a/test/models/test_real_world.py
+++ b/test/models/test_real_world.py
@@ -128,7 +128,7 @@ class TestRealWorld(unittest.TestCase):
         loss.backward()
         optimizer.step()

-      helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, (1.0/48)*BS, 123)
+      helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, (1.0/48)*BS, 126)

   @unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
   def test_train_cifar_hyp(self):
diff --git a/test/test_schedule.py b/test/test_schedule.py
index d2b1195a56..b11600291f 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -1116,7 +1116,7 @@ class TestSchedule(unittest.TestCase):
       opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1)
       opt.zero_grad()
       c2(c1(img).relu()).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 9)
+      check_schedule(opt.schedule_step(), 13)

   def test_sgd_4convs_fuse(self):
     with Tensor.train():
@@ -2502,8 +2502,8 @@ class TestUOpBecome(unittest.TestCase):
     check_schedule(add, 1)
     # NOTE: realized base is always a flat buffer
     assert UPat(Ops.BUFFER).match(add.lazydata.base, {})
-    # the Tensor UOp can optionally stack movement ops on top of BUFFER, in this case to preserve the (4, 4) shape of the tensor
-    assert UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)).match(add.lazydata, {})
+    # the Tensor UOp can optionally stack a VIEW on top of the BUFFER, in this case to preserve the (4, 4) shape of the tensor
+    assert add.lazydata is not add.lazydata.base
     self.assertEqual(add.lazydata.size, 16)
     self.assertEqual(add.lazydata.shape, (4, 4))

@@ -2528,19 +2528,19 @@ class TestUOpBecome(unittest.TestCase):

   # sometimes we prefer to perform an op before movement ops, in this case we should stack the mops on top of the new buffer
   @unittest.expectedFailure
-  def test_new_buffer_mops(self):
+  def test_reorder_expand(self):
     a = Tensor.empty(4, 1)
     b = a.expand(4, 4).reciprocal()
     check_schedule(b, 1)
-    self.assertEqual(b.lazydata.base.realized.size, 4)
-    assert UPat(Ops.EXPAND, src=(UPat(Ops.RESHAPE),)).match(b.lazydata, {}), f"{b.lazydata}"
+    self.assertEqual(b.lazydata.base.buffer.size, 4)
+    self.assertEqual(b.lazydata.st, ShapeTracker.from_shape((4, 1)).expand((4, 4)))

   def test_become_existing_buffer(self):
     a = Tensor.empty(4, 4)
     b = a*1
     assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul
     check_schedule(b, 0)
-    assert UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER))).match(b.lazydata, {}) # scheduling backtracks to the movement op if the realized tensor becomes a view
+    assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(b.lazydata, {}) # scheduling merges all MovementOps into a single VIEW
     self.assertIs(a.lazydata.base.buffer, b.lazydata.base.buffer)

   def test_become_buf_with_mops(self):
@@ -2591,20 +2591,20 @@ class TestUOpBecome(unittest.TestCase):
     a = Tensor.empty(4, 4)
     b = a.permute((1, 0))+0
     check_schedule(b, 0)
-    self.assertIs(b.lazydata, a.lazydata.permute((1, 0)))
+    self.assertEqual(b.lazydata.st, a.lazydata.permute((1, 0)).st)

   def test_become_existing_buf_view_alt(self):
     a = Tensor.empty(4, 4)
     b = a.permute((1, 0)).reshape((8, 2))+0
     check_schedule(b, 0)
-    self.assertIs(b.lazydata, a.lazydata.permute((1, 0)).reshape((8, 2)))
+    self.assertEqual(b.lazydata.st, a.lazydata.permute((1, 0)).reshape((8, 2)).st)

   # they can also have other base parents that simplified, in that case we just backtrack to the chained mops
   def test_become_existing_buf_complex(self):
     a = Tensor.empty(4, 4)
     b = (a.permute((1, 0))+0).reshape((8, 2))+0
     check_schedule(b, 0)
-    self.assertIs(b.lazydata, a.lazydata.permute((1, 0)).reshape((8, 2)))
+    self.assertEqual(b.lazydata.st, a.lazydata.permute((1, 0)).reshape((8, 2)).st)
     assert b.lazydata.is_realized

   def test_become_multiple_choices(self):
@@ -2615,9 +2615,8 @@ class TestUOpBecome(unittest.TestCase):
     assert all_same([x.lazydata.base.realized for x in [a,b,c]])
     # these movement ops result in the same ShapeTracker
     assert b.lazydata.st == c.lazydata.st
-    # the decision for which movement op to pick is local, we could also make this always pick the simplest one
-    assert UPat(Ops.SHRINK, src=(UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)),)),)).match(b.lazydata, {})
-    assert UPat(Ops.SHRINK, src=(UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)),)).match(c.lazydata, {})
+    assert b.lazydata is c.lazydata
+    assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).match(c.lazydata, {})

 if __name__ == '__main__':
   unittest.main(verbosity=2)
diff --git a/test/unit/test_gradient.py b/test/unit/test_gradient.py
index 777703a9a1..e82d401d54 100644
--- a/test/unit/test_gradient.py
+++ b/test/unit/test_gradient.py
@@ -3,7 +3,7 @@ import unittest, math
 import torch
 from tinygrad import Tensor
 from tinygrad.dtype import dtypes
-from tinygrad.ops import UOp, Ops
+from tinygrad.ops import UOp
 from tinygrad.gradient import compute_gradient

 class TestGradient(unittest.TestCase):
@@ -107,7 +107,7 @@ class TestTensorGradient(unittest.TestCase):
 class TestRealizeMeansRealize(unittest.TestCase):
   def test_randn_realizes(self):
     x = Tensor.randn(2, 3, 64, 64, requires_grad=True).realize()
-    self.assertEqual(x.lazydata.op, Ops.RESHAPE)
+    assert x.lazydata is not x.lazydata.base
     assert x.lazydata.is_realized

   #@unittest.expectedFailure
@@ -115,7 +115,7 @@ class TestRealizeMeansRealize(unittest.TestCase):
   def test_uniform_realizes(self):
     x = Tensor.uniform(16, 3, 3, 3, requires_grad=True).realize()
     print(x.lazydata)
-    self.assertEqual(x.lazydata.op, Ops.RESHAPE)
+    assert x.lazydata is not x.lazydata.base
     assert x.lazydata.is_realized

   # NOTE: even though it doesn't realize, this seems fine
diff --git a/test/unit/test_tensor_uop_representation.py b/test/unit/test_tensor_uop_representation.py
index 7fc036aa69..1c128eca80 100644
--- a/test/unit/test_tensor_uop_representation.py
+++ b/test/unit/test_tensor_uop_representation.py
@@ -34,7 +34,7 @@ class TestTensorMutates(unittest.TestCase):
     is_pattern_uop(c.lazydata.base, realized_pattern)
     # NOTE: we keep movement ops on top of the buffer view
     is_pattern_uop(c.lazydata, UPat(Ops.BUFFER))
-    is_pattern_uop(d.lazydata, UPat(Ops.RESHAPE, src=(realized_pattern,)))
+    is_pattern_uop(d.lazydata, UPat(Ops.VIEW, src=(realized_pattern,)))

   def test_reshape_is_same_child(self):
     a = Tensor([1,2,3])
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 979e5f55fd..1964d71cfc 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -72,6 +72,9 @@ sym = symbolic_simple+PatternMatcher([
   # no COPY to same device, except clone (arg is True)
   (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
    lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
+  # copyin must be base
+  (UPat(Ops.COPY, src=(UPat(), UPat(Ops.VIEW, name="v")), name="copy"), lambda copy,v: v.contiguous().copy_to_device(copy.device) \
+    if prod(v.shape) < prod(v.base.shape) else v.base.copy_to_device(copy.device, clone=copy.arg).view(v.st)),
   # remove cast to image when it's already a contiguous image
   (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"))),)),
    lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
@@ -405,14 +408,11 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va

   # map tensors to buffer/const
   becomes_map: dict[UOp, UOp] = {}
   for k,v in tensor_map.items():
-    # NOTE: tensors can also map to a VIEW, if it's contiguous and we can reshape it it's fine
-    if (a:=kernel_map.get(v.base)) is not None and a.op is Ops.ASSIGN and a.size == k.size and unwrap(v.st).contiguous:
-      becomes_map[k] = k.src[0] if k.op is Ops.ASSIGN else a.buf_uop.reshape(k.shape)
+    # NOTE: tensors can also map to a VIEW, we just apply this VIEW on top of the BUFFER
+    if (a:=kernel_map.get(v.base)) is not None and a.op is Ops.ASSIGN:
+      becomes_map[k] = a.src[0] if v is v.base else a.src[0].view(unwrap(v.st))
     if v is k: continue
-    if v.base.op is Ops.BUFFER:
-      # VIEW isn't a valid tensor uop, we need to backtrack to the movement op that created it
-      if v.op is Ops.VIEW: v = next(iter(x for x in k.toposort if (xs:=tensor_map[x]).base is v.base and xs.st == v.st))
-      if k is not v: becomes_map[k] = v
+    if v.base.op is Ops.BUFFER: becomes_map[k] = v
     elif v.base.op is Ops.CONST:
       if all_int(v.shape): becomes_map[k] = v
diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py
index 8e35cbd267..e5cfb4aa11 100644
--- a/tinygrad/nn/optim.py
+++ b/tinygrad/nn/optim.py
@@ -88,7 +88,9 @@ class LARS(Optimizer):
       # classic momentum does post learning rate update
       if self.classic: g = g * r * self.lr
       if self.momentum:
-        self.b[i].assign(self.momentum * self.b[i] + g)  # NOTE: self.b[i] is zero on the first run, no if required
+        # TODO: this contiguous is required for correctness because self.b[i] becomes a non-contiguous view
+        # the scheduler should detect this and just insert contiguous
+        self.b[i].assign(self.momentum * self.b[i].contiguous() + g)  # NOTE: self.b[i] is zero on the first run, no if required
         g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
       # popular momentum does pre learning rate update
       if not self.classic: g = g * r * self.lr
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 3a77241cf0..0314dca680 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -466,18 +466,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     # otherwise it's just a RESHAPE(BUFFER)
     if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}")
     return UOp.new_buffer(device, size, dtype).reshape(shape)
-  def copy_to_device(self, device:str|tuple[str, ...], clone:bool=False) -> UOp:
-    # if it's a shrink, do the shrink before the copy with CONTIGUOUS
-    if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device)
-    # COPY is COPY(DEVICE, copyin.base) -> VIEW(copyin.st)
-    ret = UOp(Ops.COPY, self.base.dtype, (UOp(Ops.DEVICE, arg=device), self.base), clone)
-    op_arg = []
-    mop = self
-    while mop is not self.base:
-      op_arg.append((mop.op, mop.arg))
-      mop = mop.src[0]
-    for op,arg in reversed(op_arg): ret = UOp(op, ret.dtype, (ret,), arg)
-    return ret
+  def copy_to_device(self, device:str|tuple[str, ...], clone:bool=False): return UOp(Ops.COPY, self.dtype, (UOp(Ops.DEVICE, arg=device), self), clone)
   def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True)
   @property
   def metadata(self) -> tuple[Metadata, ...]|Metadata|None: return self.arg.metadata if self.op is Ops.KERNEL else all_metadata.get(self, None)
diff --git a/tinygrad/spec.py b/tinygrad/spec.py
index 88e1f94678..56105ab6c4 100644
--- a/tinygrad/spec.py
+++ b/tinygrad/spec.py
@@ -19,7 +19,7 @@ tensor_uop_spec = buffer_spec+PatternMatcher([
    # "make things that can't be images not images" can change the buffer dtype
    # this is fine as long as it's a realized buffer and base dtypes match.
    ((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base and x.is_realized)),
-  (UPat(Ops.VIEW, src=(UPat(GroupOp.All-{Ops.CONST, Ops.DEVICE}),)), lambda: False),
+  (UPat(Ops.VIEW, src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE}),)), lambda: False),

   # Tensor variable bindings
   (UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),
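
Note (not part of the patch): a minimal sketch of the representation change the tests above assert, assuming a tinygrad checkout with this change applied. After scheduling, a tensor that folds to a no-op over an existing buffer is rewritten to a single VIEW stacked on its BUFFER, rather than a chain of RESHAPE/PERMUTE movement ops. It mirrors test_become_existing_buffer, using Tensor.realize() in place of the check_schedule test helper.

# illustrative sketch, not part of the diff above
from tinygrad import Tensor
from tinygrad.ops import Ops, UPat

a = Tensor.empty(4, 4)
b = a*1
assert UPat(Ops.MUL).match(b.lazydata, {})  # before scheduling it's still a MUL
b.realize()                                 # scheduling rewrites b.lazydata in place
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).match(b.lazydata, {})  # a single VIEW on top of the BUFFER
assert a.lazydata.base.buffer is b.lazydata.base.buffer               # b reuses a's buffer, no new allocation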