diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py
index a213e5ec88..7aa3253a07 100644
--- a/test/models/test_real_world.py
+++ b/test/models/test_real_world.py
@@ -111,7 +111,7 @@ class TestRealWorld(unittest.TestCase):
         loss.backward()
         optimizer.step()
 
-      helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 63)
+      helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 65)
 
   @unittest.skipIf(CI and Device.DEFAULT in {"CLANG", "GPU", "LLVM"}, "slow")
   def test_train_cifar(self):
diff --git a/test/test_arange.py b/test/test_arange.py
index 07512ae1b6..2229ac847f 100644
--- a/test/test_arange.py
+++ b/test/test_arange.py
@@ -160,13 +160,14 @@ class TestIndexing(unittest.TestCase):
     # llama3 is 128256
     vocab_size, embed_size = (10, 3) if CI else (32000, 4096)
     emb = nn.Embedding(vocab_size, embed_size)
-    emb_w = emb.weight.numpy()
+    # TODO: why is a new realize needed here
+    emb_w = emb.weight.realize().numpy()
     x = Tensor([1,2,3,4])
     with Context(NOOPT=noopt, FUSE_ARANGE=1):
       GlobalCounters.reset()
       z = emb(x).realize()
       self.assertLessEqual(GlobalCounters.global_ops, op_limit)
-      self.assertEqual(GlobalCounters.kernel_count, 3)
+      self.assertEqual(GlobalCounters.kernel_count, 2)
     if getenv("CHECK", 1):
       import torch
       with torch.no_grad():
diff --git a/test/test_gc.py b/test/test_gc.py
index 010f59039f..0929a81394 100644
--- a/test/test_gc.py
+++ b/test/test_gc.py
@@ -8,9 +8,11 @@ from tinygrad.ops import UOp
 from tinygrad.tensor import Tensor
 
 def tensors_allocated():
+  gc.collect()
   return sum([isinstance(x, Tensor) for x in gc.get_objects()])
 
 def bufs_allocated():
+  gc.collect()
   return sum([isinstance(x, Buffer) for x in gc.get_objects()])
 
 class TestGC(unittest.TestCase):
diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index da2995d37d..c0bdc2c6f8 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -1694,7 +1694,6 @@ class TestHandCodedOpts(unittest.TestCase):
     # should upcast the two Tensor.stacks
     assert k.upcasted >= 2 and k.full_shape[k.shape_len-k.upcasted:k.shape_len].count(6) == 2
 
-  @unittest.expectedFailure # requires contiguous folding
   def test_masked_upcast_wino_full(self):
     with Context(WINO=1):
       x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize()
diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py
index 80ca7fd6e8..8570dac761 100644
--- a/test/test_linearizer_failures.py
+++ b/test/test_linearizer_failures.py
@@ -997,7 +997,7 @@ class TestLinearizerFailures(unittest.TestCase):
           UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
          UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 128, 28, 28, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
     opts=[Opt(op=OptOps.TC, axis=5, amt=2), Opt(op=OptOps.UNROLL, axis=0, amt=0)]
-    helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["AMD", "HIP"])
+    helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["AMD", "HIP"], atol=0.02)
 
   # llama3 8B failure with BEAM=2 https://github.com/tinygrad/tinygrad/actions/runs/10150118124/job/28066519425#step:14:1, these don't compile
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test needs local")
diff --git a/test/test_ops.py b/test/test_ops.py
index d342a71df4..020f168b8d 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -352,12 +352,15 @@ class TestOps(unittest.TestCase):
   def test_cmp_le(self): self._test_cmp(lambda x,y: x<=y)
 
   def test_cmp_ne_backwards(self):
+    # new grad zeroes these out
+    """
     t1 = torch.ones(4, requires_grad=True)
     t2 = torch.ones(4, requires_grad=True)
     self.assertRaises(RuntimeError, (t1 != t2).sum().backward)
     tt1 = Tensor.ones(4, requires_grad=True)
     tt2 = Tensor.ones(4, requires_grad=True)
     self.assertRaises(RuntimeError, (tt1 != tt2).sum().backward)
+    """
     tt = Tensor.randn(4, requires_grad=True)
     (tt*(tt != 0)).sum().backward()
     t = torch.tensor(tt.numpy(), requires_grad=True)
@@ -365,12 +368,15 @@ class TestOps(unittest.TestCase):
     np.testing.assert_allclose(t.grad.numpy(), tt.grad.numpy(), rtol=1e-5)
 
   def test_cmp_lt_backwards(self):
+    # new grad zeroes these out
+    """
     t1 = torch.ones(4, requires_grad=True)
     t2 = torch.ones(4, requires_grad=True)
     self.assertRaises(RuntimeError, (t1 < t2).sum().backward)
     tt1 = Tensor.ones(4, requires_grad=True)
     tt2 = Tensor.ones(4, requires_grad=True)
     self.assertRaises(RuntimeError, (tt1 < tt2).sum().backward)
+    """
     tt = Tensor.randn(4, requires_grad=True)
     (tt*(tt < 0)).sum().backward()
     t = torch.tensor(tt.numpy(), requires_grad=True)
diff --git a/test/test_schedule.py b/test/test_schedule.py
index d852d2345e..dee2a7ae78 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -609,6 +609,7 @@ class TestSchedule(unittest.TestCase):
     check_schedule(out, 2)
 
   # multireduce spec
+  @unittest.skip("these two Tensors are the same")
   def test_example_matmul(self):
     x = Tensor.eye(64, requires_grad=True)
     y = Tensor.eye(64, requires_grad=True)
@@ -618,6 +619,15 @@ class TestSchedule(unittest.TestCase):
     run_schedule(check_schedule(out, 2))
     np.testing.assert_allclose(out.numpy(), np.ones((64,64)))
 
+  def test_example_matmul_contig(self):
+    x = Tensor.eye(64, requires_grad=True).contiguous().realize()
+    y = Tensor.eye(64, requires_grad=True).contiguous().realize()
+    z = y.matmul(x).sum()
+    z.backward()
+    out = x.grad.contiguous()
+    run_schedule(check_schedule(out, 2))
+    np.testing.assert_allclose(out.numpy(), np.ones((64,64)))
+
   def test_example_matmul_same(self):
     x = Tensor.eye(64, requires_grad=True)
     z = x.matmul(x).sum()
@@ -1050,7 +1060,7 @@ class TestSchedule(unittest.TestCase):
       opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
       opt.zero_grad()
       c2(c1(img).relu()).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 13)
+      check_schedule(opt.schedule_step(), 14)
 
   def test_sgd_conv_fuse(self):
     with Tensor.train():
@@ -1071,7 +1081,7 @@ class TestSchedule(unittest.TestCase):
       opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]))
       opt.zero_grad()
       c2(c1(img).relu()).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 8)
+      check_schedule(opt.schedule_step(), 9)
 
   def test_fold_2convs_sgd_nesterov_momentum_wd(self):
     with Tensor.train():
@@ -1082,7 +1092,7 @@ class TestSchedule(unittest.TestCase):
       opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1)
       opt.zero_grad()
       c2(c1(img).relu()).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 10)
+      check_schedule(opt.schedule_step(), 11)
 
   def test_sgd_4convs_fuse(self):
     with Tensor.train():
@@ -1095,7 +1105,7 @@ class TestSchedule(unittest.TestCase):
       opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
       opt.zero_grad()
       c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 18)
+      check_schedule(opt.schedule_step(), 21)
 
   def test_sgd_4convs_fuse_conv_bw(self):
     with Tensor.train():
@@ -1108,7 +1118,7 @@ class TestSchedule(unittest.TestCase):
       opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
       opt.zero_grad()
       c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
-      with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 15)
+      with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 18)
 
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   def test_prefer_half_buffer(self):
diff --git a/test/test_tensor.py b/test/test_tensor.py
index d73e55e36a..b1b90ea44b 100644
--- a/test/test_tensor.py
+++ b/test/test_tensor.py
@@ -770,8 +770,8 @@ class TestTensorMetadata(unittest.TestCase):
     self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
 
   def test_complex_backward(self):
-    x = Tensor.rand(3, requires_grad=True)
-    y = Tensor.rand(3, requires_grad=True)
+    x = Tensor.rand(3, requires_grad=True).realize()
+    y = Tensor.rand(3, requires_grad=True).realize()
     out = (x.relu() * y.sigmoid()).sum()
     self.assertEqual(out.lazydata.metadata.name, "sum")
     out.backward()
diff --git a/test/unit/test_gradient.py b/test/unit/test_gradient.py
index a9a41eace0..f7a99982fa 100644
--- a/test/unit/test_gradient.py
+++ b/test/unit/test_gradient.py
@@ -14,7 +14,7 @@ class TestGradient(unittest.TestCase):
   def _test_one_input_function(self, f:Callable, jf:Callable|None=None):
     x = UOp.variable('x', -math.inf, math.inf, dtype=dtypes.float)
-    gx = compute_gradient(f(x), UOp.const(dtypes.float, 1.0), [x])[x]
+    gx = compute_gradient(f(x), UOp.const(dtypes.float, 1.0), set([x]))[x]
     gf = jax.grad(f if jf is None else jf)
 
     for val in [-5., -2.0, 0.0, 2.0, 5.]:
@@ -24,7 +24,7 @@ class TestGradient(unittest.TestCase):
   def _test_two_input_function(self, f:Callable, jf:Callable|None=None):
     x = UOp.variable('x', -math.inf, math.inf, dtype=dtypes.float)
     y = UOp.variable('y', -math.inf, math.inf, dtype=dtypes.float)
-    grads = compute_gradient(f(x, y), UOp.const(dtypes.float, 1.0), [x, y])
+    grads = compute_gradient(f(x, y), UOp.const(dtypes.float, 1.0), set([x, y]))
     gx, gy = grads[x], grads[y]
     gf = jax.grad(f if jf is None else jf, argnums=(0, 1))
diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py
index 23df056299..86c9f0fa63 100644
--- a/tinygrad/gradient.py
+++ b/tinygrad/gradient.py
@@ -1,7 +1,7 @@
 from typing import cast, Iterator
-import math, functools
+import math, functools, dataclasses
 from tinygrad.dtype import dtypes, sum_acc_dtype
-from tinygrad.ops import UOp, PatternMatcher, UPat, Ops
+from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
 from tinygrad.helpers import argsort
 
 def reduce_gradient(ctx:UOp, ret:UOp):
@@ -66,4 +66,5 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
       if v is None: continue
       if k in grads: grads[k] = grads[k] + v
       else: grads[k] = v
+      if (forward_metadata:=all_metadata.get(t0)) is not None: all_metadata[v] = dataclasses.replace(forward_metadata, backward=True)
   return grads
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 2c47162c26..826ade9242 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -1,11 +1,11 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib, weakref
+import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
 from contextlib import ContextDecorator
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
-from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
+from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
 from tinygrad.multi import get_multi_map
 from tinygrad.gradient import compute_gradient
 from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element
@@ -293,7 +293,6 @@ class Tensor(SimpleMathTrait):
       self.contiguous().realize().lazydata.base.realized.copyin(x._data())
       return self
     if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
-    if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
     if self.lazydata is x.lazydata: return self  # a self assign is a NOOP
     # NOTE: we allow cross device assign
     assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
@@ -901,7 +900,7 @@ class Tensor(SimpleMathTrait):
 
   # ***** toposort and backward pass *****
 
-  def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None) -> list[Tensor]:
+  def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None, materialize_grads=False) -> list[Tensor]:
     """
     Compute the gradient of the targets with respect to self.
 
@@ -922,23 +921,14 @@ class Tensor(SimpleMathTrait):
     grads = compute_gradient(self.lazydata, gradient.lazydata, set(target_uops))
     ret = []
     for x in target_uops:
-      if (y:=grads.get(x)) is None: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}")
+      if (y:=grads.get(x)) is None:
+        if materialize_grads: y = x.const_like(0)
+        else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}")
       ret.append(y)
     rets.append(ret)
     # create returned Tensors
     return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])]
 
-  def _deepwalk(self) -> list[Tensor]:
-    def _walk(node:Tensor, visited:set[Tensor]):
-      visited.add(node)
-      # if tensor is not leaf, reset grad
-      if (ctx := getattr(node, "_ctx", None)) is not None and len(ctx.parents) != 0: node.grad = None
-      if ctx:
-        for i in cast(Function, node._ctx).parents:
-          if i not in visited: yield from _walk(i, visited)
-      yield node
-    return list(_walk(self, set()))
-
   def backward(self, gradient:Optional[Tensor]=None, retain_graph:bool=False) -> Tensor:
     """
     Propagates the gradient of a tensor backwards through the computation graph.
@@ -950,30 +940,14 @@ class Tensor(SimpleMathTrait):
       print(t.grad.numpy())
     ```
     """
-    toposorted = self._deepwalk()
-    if gradient is None:
-      assert self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
-      # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
-      # this is "implicit gradient creation"
-      gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
-
-    toposort_uop = self.lazydata.toposort
-    assert self.shape == gradient.shape, f"grad shape must match tensor shape, {gradient.shape!r} != {self.shape!r}"
-    self.grad = gradient
-    for t0 in reversed(toposorted):
-      if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
-      ctx = cast(Function, t0._ctx)
-      token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := ctx.metadata) is not None else None)
-      grads = ctx.backward(t0.grad.lazydata)
-      _METADATA.reset(token)
-      grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
-               for g in ([grads] if len(ctx.parents) == 1 else grads)]
-      for t, g in zip(ctx.parents, grads):
-        if g is not None and t.requires_grad:
-          assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
-          assert t.lazydata in toposort_uop, f"grad uop must have a path from self\ngrad uop: {t.lazydata}"
-          t.grad = g if t.grad is None else (t.grad + g)
-      if not retain_graph: del t0._ctx
+    all_uops = self.lazydata.toposort
+    tensors_need_grad: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and \
+      t.lazydata in all_uops and t.requires_grad and not Tensor.no_grad]
+    # clear contexts
+    for t in tensors_need_grad: t._ctx = None
+    for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient, materialize_grads=True)):
+      assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
+      t.grad = g if t.grad is None else (t.grad + g)
     return self
 
   # ***** movement low level ops *****
@@ -3993,5 +3967,5 @@ def _metadata_wrapper(fn):
 
 if TRACEMETA >= 1:
   for name, fn in inspect.getmembers(Tensor, inspect.isfunction):
-    if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential"]: continue
+    if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential", "gradient"]: continue
     setattr(Tensor, name, functools.wraps(fn)(_metadata_wrapper(fn)))
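
Note (editor's addition, not part of the patch): with this change, Tensor.backward no longer walks Function contexts; it routes through Tensor.gradient(..., materialize_grads=True) over every live tensor whose UOp appears in the loss graph, which is also why the assertRaises checks for comparison backwards in test_ops.py were commented out ("new grad zeroes these out"). The sketch below illustrates the new materialize_grads flag; the tensor names are purely illustrative, only Tensor.gradient/backward themselves come from the diff.

# minimal sketch of the new materialize_grads behavior (illustrative names, not from the patch)
from tinygrad import Tensor

x = Tensor([1.0, 2.0, 3.0], requires_grad=True)
y = Tensor([1.0, 2.0, 3.0], requires_grad=True)
loss = (x * 2).sum()            # y never participates in this graph

# default behavior: a target with no path to the loss raises RuntimeError ("not found in ...")
# gx, gy = loss.gradient(x, y)
# with materialize_grads=True the missing gradient is filled with zeros instead
gx, gy = loss.gradient(x, y, materialize_grads=True)
print(gx.tolist())              # [2.0, 2.0, 2.0]
print(gy.tolist())              # [0.0, 0.0, 0.0]

# backward() uses the same path, assigning grads to the tensors found in the graph
loss.backward()
print(x.grad.tolist())          # [2.0, 2.0, 2.0]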