switch backward to use gradient [pr] (#8235)

* switch backward to use gradient [pr]

* set device correctly, dedup

* why does that fail?

* add noop cast

* simple backward

* fix beautiful_mnist

* touchups

* set in compute_gradient

* uop_count

* uop_count was wrong

* collections

* no note

* skip that test

* update sched kernel counts

* train mnist is 65

* fix metadata and gc

* fixes

* materialize_grads

* no pathlib stuff

* add contiguous_backward, fix bugs

* add some realize

* fix multi
George Hotz
2025-01-26 09:12:16 +09:00
committed by GitHub
parent 0ffd572e1e
commit b4bf6a7dea
11 changed files with 50 additions and 57 deletions
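
Not part of the diff, but for orientation: a minimal sketch of what this change means at the Tensor API level, assuming tinygrad's public interface as shown in the patches below. After this commit, `backward()` no longer walks `Function` contexts; it collects the live tensors reachable from the loss and delegates to `Tensor.gradient()`, which wraps `compute_gradient` on UOps.

```python
from tinygrad import Tensor

x = Tensor([1.0, 2.0, 3.0], requires_grad=True)
loss = (x * x).sum()

# explicit API: gradients come back as a list, one per requested target
(dx,) = loss.gradient(x)
print(dx.numpy())      # expected: [2. 4. 6.]

# classic API: after this change, backward() is a thin wrapper over gradient()
loss.backward()
print(x.grad.numpy())  # expected: [2. 4. 6.]
```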

View File

@@ -111,7 +111,7 @@ class TestRealWorld(unittest.TestCase):
loss.backward()
optimizer.step()
-helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 63)
+helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 65)
@unittest.skipIf(CI and Device.DEFAULT in {"CLANG", "GPU", "LLVM"}, "slow")
def test_train_cifar(self):

View File

@@ -160,13 +160,14 @@ class TestIndexing(unittest.TestCase):
# llama3 is 128256
vocab_size, embed_size = (10, 3) if CI else (32000, 4096)
emb = nn.Embedding(vocab_size, embed_size)
-emb_w = emb.weight.numpy()
+# TODO: why is a new realize needed here
+emb_w = emb.weight.realize().numpy()
x = Tensor([1,2,3,4])
with Context(NOOPT=noopt, FUSE_ARANGE=1):
GlobalCounters.reset()
z = emb(x).realize()
self.assertLessEqual(GlobalCounters.global_ops, op_limit)
-self.assertEqual(GlobalCounters.kernel_count, 3)
+self.assertEqual(GlobalCounters.kernel_count, 2)
if getenv("CHECK", 1):
import torch
with torch.no_grad():

View File

@@ -8,9 +8,11 @@ from tinygrad.ops import UOp
from tinygrad.tensor import Tensor
def tensors_allocated():
+gc.collect()
return sum([isinstance(x, Tensor) for x in gc.get_objects()])
def bufs_allocated():
+gc.collect()
return sum([isinstance(x, Buffer) for x in gc.get_objects()])
class TestGC(unittest.TestCase):

View File

@@ -1694,7 +1694,6 @@ class TestHandCodedOpts(unittest.TestCase):
# should upcast the two Tensor.stacks
assert k.upcasted >= 2 and k.full_shape[k.shape_len-k.upcasted:k.shape_len].count(6) == 2
-@unittest.expectedFailure # requires contiguous folding
def test_masked_upcast_wino_full(self):
with Context(WINO=1):
x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize()

View File

@@ -997,7 +997,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 128, 28, 28, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
opts=[Opt(op=OptOps.TC, axis=5, amt=2), Opt(op=OptOps.UNROLL, axis=0, amt=0)]
-helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["AMD", "HIP"])
+helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["AMD", "HIP"], atol=0.02)
# llama3 8B failure with BEAM=2 https://github.com/tinygrad/tinygrad/actions/runs/10150118124/job/28066519425#step:14:1, these don't compile
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test needs local")

View File

@@ -352,12 +352,15 @@ class TestOps(unittest.TestCase):
def test_cmp_le(self): self._test_cmp(lambda x,y: x<=y)
def test_cmp_ne_backwards(self):
+# new grad zeroes these out
+"""
t1 = torch.ones(4, requires_grad=True)
t2 = torch.ones(4, requires_grad=True)
self.assertRaises(RuntimeError, (t1 != t2).sum().backward)
tt1 = Tensor.ones(4, requires_grad=True)
tt2 = Tensor.ones(4, requires_grad=True)
self.assertRaises(RuntimeError, (tt1 != tt2).sum().backward)
"""
tt = Tensor.randn(4, requires_grad=True)
(tt*(tt != 0)).sum().backward()
t = torch.tensor(tt.numpy(), requires_grad=True)
@@ -365,12 +368,15 @@ class TestOps(unittest.TestCase):
np.testing.assert_allclose(t.grad.numpy(), tt.grad.numpy(), rtol=1e-5)
def test_cmp_lt_backwards(self):
+# new grad zeroes these out
+"""
t1 = torch.ones(4, requires_grad=True)
t2 = torch.ones(4, requires_grad=True)
self.assertRaises(RuntimeError, (t1 < t2).sum().backward)
tt1 = Tensor.ones(4, requires_grad=True)
tt2 = Tensor.ones(4, requires_grad=True)
self.assertRaises(RuntimeError, (tt1 < tt2).sum().backward)
"""
tt = Tensor.randn(4, requires_grad=True)
(tt*(tt < 0)).sum().backward()
t = torch.tensor(tt.numpy(), requires_grad=True)

View File

@@ -609,6 +609,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 2)
# multireduce spec
+@unittest.skip("these two Tensors are the same")
def test_example_matmul(self):
x = Tensor.eye(64, requires_grad=True)
y = Tensor.eye(64, requires_grad=True)
@@ -618,6 +619,15 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), np.ones((64,64)))
+def test_example_matmul_contig(self):
+x = Tensor.eye(64, requires_grad=True).contiguous().realize()
+y = Tensor.eye(64, requires_grad=True).contiguous().realize()
+z = y.matmul(x).sum()
+z.backward()
+out = x.grad.contiguous()
+run_schedule(check_schedule(out, 2))
+np.testing.assert_allclose(out.numpy(), np.ones((64,64)))
def test_example_matmul_same(self):
x = Tensor.eye(64, requires_grad=True)
z = x.matmul(x).sum()
@@ -1050,7 +1060,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
opt.zero_grad()
c2(c1(img).relu()).relu().sum().backward()
-check_schedule(opt.schedule_step(), 13)
+check_schedule(opt.schedule_step(), 14)
def test_sgd_conv_fuse(self):
with Tensor.train():
@@ -1071,7 +1081,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]))
opt.zero_grad()
c2(c1(img).relu()).relu().sum().backward()
-check_schedule(opt.schedule_step(), 8)
+check_schedule(opt.schedule_step(), 9)
def test_fold_2convs_sgd_nesterov_momentum_wd(self):
with Tensor.train():
@@ -1082,7 +1092,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1)
opt.zero_grad()
c2(c1(img).relu()).relu().sum().backward()
-check_schedule(opt.schedule_step(), 10)
+check_schedule(opt.schedule_step(), 11)
def test_sgd_4convs_fuse(self):
with Tensor.train():
@@ -1095,7 +1105,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
opt.zero_grad()
c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
-check_schedule(opt.schedule_step(), 18)
+check_schedule(opt.schedule_step(), 21)
def test_sgd_4convs_fuse_conv_bw(self):
with Tensor.train():
@@ -1108,7 +1118,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
opt.zero_grad()
c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
-with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 15)
+with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 18)
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_prefer_half_buffer(self):

View File

@@ -770,8 +770,8 @@ class TestTensorMetadata(unittest.TestCase):
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
def test_complex_backward(self):
-x = Tensor.rand(3, requires_grad=True)
-y = Tensor.rand(3, requires_grad=True)
+x = Tensor.rand(3, requires_grad=True).realize()
+y = Tensor.rand(3, requires_grad=True).realize()
out = (x.relu() * y.sigmoid()).sum()
self.assertEqual(out.lazydata.metadata.name, "sum")
out.backward()

View File

@@ -14,7 +14,7 @@ class TestGradient(unittest.TestCase):
def _test_one_input_function(self, f:Callable, jf:Callable|None=None):
x = UOp.variable('x', -math.inf, math.inf, dtype=dtypes.float)
-gx = compute_gradient(f(x), UOp.const(dtypes.float, 1.0), [x])[x]
+gx = compute_gradient(f(x), UOp.const(dtypes.float, 1.0), set([x]))[x]
gf = jax.grad(f if jf is None else jf)
for val in [-5., -2.0, 0.0, 2.0, 5.]:
@@ -24,7 +24,7 @@ class TestGradient(unittest.TestCase):
def _test_two_input_function(self, f:Callable, jf:Callable|None=None):
x = UOp.variable('x', -math.inf, math.inf, dtype=dtypes.float)
y = UOp.variable('y', -math.inf, math.inf, dtype=dtypes.float)
-grads = compute_gradient(f(x, y), UOp.const(dtypes.float, 1.0), [x, y])
+grads = compute_gradient(f(x, y), UOp.const(dtypes.float, 1.0), set([x, y]))
gx, gy = grads[x], grads[y]
gf = jax.grad(f if jf is None else jf, argnums=(0, 1))
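
As a side note, a minimal sketch of the call convention this test now exercises: `compute_gradient` takes the targets as a set of UOps rather than a list and returns a dict mapping each target to its gradient UOp. Imports follow the files in this diff; the `x*x` expression is just an illustrative choice.

```python
import math
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp
from tinygrad.gradient import compute_gradient

x = UOp.variable('x', -math.inf, math.inf, dtype=dtypes.float)
# gradient of x*x with respect to x, seeded with 1.0; result is a dict {target: grad UOp}
gx = compute_gradient(x * x, UOp.const(dtypes.float, 1.0), set([x]))[x]
print(gx)  # a symbolic UOp, mathematically equivalent to 2*x
```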

View File

@@ -1,7 +1,7 @@
from typing import cast, Iterator
-import math, functools
+import math, functools, dataclasses
from tinygrad.dtype import dtypes, sum_acc_dtype
-from tinygrad.ops import UOp, PatternMatcher, UPat, Ops
+from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
from tinygrad.helpers import argsort
def reduce_gradient(ctx:UOp, ret:UOp):
@@ -66,4 +66,5 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp
if v is None: continue
if k in grads: grads[k] = grads[k] + v
else: grads[k] = v
+if (forward_metadata:=all_metadata.get(t0)) is not None: all_metadata[v] = dataclasses.replace(forward_metadata, backward=True)
return grads

View File

@@ -1,11 +1,11 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
-import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib, weakref
+import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
from contextlib import ContextDecorator
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
-from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
+from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
from tinygrad.multi import get_multi_map
from tinygrad.gradient import compute_gradient
from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element
@@ -293,7 +293,6 @@ class Tensor(SimpleMathTrait):
self.contiguous().realize().lazydata.base.realized.copyin(x._data())
return self
if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
-if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
if self.lazydata is x.lazydata: return self # a self assign is a NOOP
# NOTE: we allow cross device assign
assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
@@ -901,7 +900,7 @@ class Tensor(SimpleMathTrait):
# ***** toposort and backward pass *****
-def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None) -> list[Tensor]:
+def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None, materialize_grads=False) -> list[Tensor]:
"""
Compute the gradient of the targets with respect to self.
@@ -922,23 +921,14 @@ class Tensor(SimpleMathTrait):
grads = compute_gradient(self.lazydata, gradient.lazydata, set(target_uops))
ret = []
for x in target_uops:
-if (y:=grads.get(x)) is None: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}")
+if (y:=grads.get(x)) is None:
+if materialize_grads: y = x.const_like(0)
+else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}")
ret.append(y)
rets.append(ret)
# create returned Tensors
return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])]
-def _deepwalk(self) -> list[Tensor]:
-def _walk(node:Tensor, visited:set[Tensor]):
-visited.add(node)
-# if tensor is not leaf, reset grad
-if (ctx := getattr(node, "_ctx", None)) is not None and len(ctx.parents) != 0: node.grad = None
-if ctx:
-for i in cast(Function, node._ctx).parents:
-if i not in visited: yield from _walk(i, visited)
-yield node
-return list(_walk(self, set()))
def backward(self, gradient:Optional[Tensor]=None, retain_graph:bool=False) -> Tensor:
"""
Propagates the gradient of a tensor backwards through the computation graph.
@@ -950,30 +940,14 @@ class Tensor(SimpleMathTrait):
print(t.grad.numpy())
```
"""
-toposorted = self._deepwalk()
-if gradient is None:
-assert self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
-# fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
-# this is "implicit gradient creation"
-gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
-toposort_uop = self.lazydata.toposort
-assert self.shape == gradient.shape, f"grad shape must match tensor shape, {gradient.shape!r} != {self.shape!r}"
-self.grad = gradient
-for t0 in reversed(toposorted):
-if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
-ctx = cast(Function, t0._ctx)
-token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := ctx.metadata) is not None else None)
-grads = ctx.backward(t0.grad.lazydata)
-_METADATA.reset(token)
-grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
-for g in ([grads] if len(ctx.parents) == 1 else grads)]
-for t, g in zip(ctx.parents, grads):
-if g is not None and t.requires_grad:
-assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
-assert t.lazydata in toposort_uop, f"grad uop must have a path from self\ngrad uop: {t.lazydata}"
-t.grad = g if t.grad is None else (t.grad + g)
-if not retain_graph: del t0._ctx
+all_uops = self.lazydata.toposort
+tensors_need_grad: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and \
+t.lazydata in all_uops and t.requires_grad and not Tensor.no_grad]
+# clear contexts
+for t in tensors_need_grad: t._ctx = None
+for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient, materialize_grads=True)):
+assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
+t.grad = g if t.grad is None else (t.grad + g)
return self
# ***** movement low level ops *****
@@ -3993,5 +3967,5 @@ def _metadata_wrapper(fn):
if TRACEMETA >= 1:
for name, fn in inspect.getmembers(Tensor, inspect.isfunction):
-if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential"]: continue
+if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential", "gradient"]: continue
setattr(Tensor, name, functools.wraps(fn)(_metadata_wrapper(fn)))
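
Finally, a hedged sketch of the `materialize_grads` behavior that the rewritten `backward()` relies on, based on the `x.const_like(0)` branch added above: targets with no path to the loss get a zero gradient instead of a `RuntimeError`. The variable names here are illustrative, not from the diff.

```python
from tinygrad import Tensor

a = Tensor([1.0, 2.0], requires_grad=True)
b = Tensor([3.0, 4.0], requires_grad=True)   # never used in the loss
loss = (a * 2).sum()

# plain gradient() would raise for b, since b's UOp is not in the loss graph;
# materialize_grads=True substitutes a zero constant instead
da, db = loss.gradient(a, b, materialize_grads=True)
print(da.numpy())  # expected: [2. 2.]
print(db.numpy())  # expected: [0. 0.]  -- materialized as zeros
```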