Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 07:28:15 -05:00
switch backward to use gradient [pr] (#8235)

* switch backward to use gradient [pr]
* set device correctly, dedup
* why does that fail?
* add noop cast
* simple backward
* fix beautiful_mnist
* touchups
* set in compute_gradient
* uop_count
* uop_count was wrong
* collections
* no note
* skip that test
* update sched kernel counts
* train mnist is 65
* fix metadata and gc
* fixes
* materialize_grads
* no pathlib stuff
* add contiguous_backward, fix bugs
* add some realize
* fix multi
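To make the change concrete, here is a minimal sketch (not taken from the diff; the values and shapes are illustrative) of the two entry points this commit rewires: `Tensor.backward`, which now derives gradients from the UOp graph via `Tensor.gradient`/`compute_gradient` instead of replaying `Function._ctx` contexts, and `Tensor.gradient`, which can also be called directly.

```python
from tinygrad import Tensor

# two leaf tensors and a scalar loss
x = Tensor([1.0, 2.0, 3.0], requires_grad=True)
y = Tensor([4.0, 5.0, 6.0], requires_grad=True)
loss = (x * y).sum()

# backward() now finds every reachable requires_grad tensor in the UOp graph
# and asks Tensor.gradient for all of their gradients at once
loss.backward()
print(x.grad.numpy())  # d(loss)/dx == y
print(y.grad.numpy())  # d(loss)/dy == x

# the same gradients can be requested directly, without touching .grad
gx, gy = loss.gradient(x, y)
print(gx.numpy(), gy.numpy())
```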
@@ -111,7 +111,7 @@ class TestRealWorld(unittest.TestCase):
       loss.backward()
       optimizer.step()
 
-    helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 63)
+    helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 65)
 
   @unittest.skipIf(CI and Device.DEFAULT in {"CLANG", "GPU", "LLVM"}, "slow")
   def test_train_cifar(self):
@@ -160,13 +160,14 @@ class TestIndexing(unittest.TestCase):
     # llama3 is 128256
     vocab_size, embed_size = (10, 3) if CI else (32000, 4096)
     emb = nn.Embedding(vocab_size, embed_size)
-    emb_w = emb.weight.numpy()
+    # TODO: why is a new realize needed here
+    emb_w = emb.weight.realize().numpy()
     x = Tensor([1,2,3,4])
     with Context(NOOPT=noopt, FUSE_ARANGE=1):
       GlobalCounters.reset()
       z = emb(x).realize()
       self.assertLessEqual(GlobalCounters.global_ops, op_limit)
-      self.assertEqual(GlobalCounters.kernel_count, 3)
+      self.assertEqual(GlobalCounters.kernel_count, 2)
     if getenv("CHECK", 1):
       import torch
       with torch.no_grad():
@@ -8,9 +8,11 @@ from tinygrad.ops import UOp
 from tinygrad.tensor import Tensor
 
 def tensors_allocated():
+  gc.collect()
   return sum([isinstance(x, Tensor) for x in gc.get_objects()])
 
 def bufs_allocated():
+  gc.collect()
   return sum([isinstance(x, Buffer) for x in gc.get_objects()])
 
 class TestGC(unittest.TestCase):
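The added gc.collect() makes these leak counters deterministic: counts are taken only after unreachable objects have been collected. A small sketch of how such a check can be used; the tinygrad.device.Buffer import path and the exact collection behavior here are assumptions, not part of the diff.

```python
import gc
from tinygrad import Tensor
from tinygrad.device import Buffer  # assumed import path for Buffer

def bufs_allocated() -> int:
  gc.collect()  # collect unreachable objects first, so only truly live Buffers are counted
  return sum([isinstance(x, Buffer) for x in gc.get_objects()])

base = bufs_allocated()
a = Tensor.ones(16).contiguous().realize()
print(bufs_allocated() - base)  # expected: 1 while `a` is alive
del a
print(bufs_allocated() - base)  # expected: back to 0 once nothing references the buffer
```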
@@ -1694,7 +1694,6 @@ class TestHandCodedOpts(unittest.TestCase):
     # should upcast the two Tensor.stacks
     assert k.upcasted >= 2 and k.full_shape[k.shape_len-k.upcasted:k.shape_len].count(6) == 2
 
-  @unittest.expectedFailure # requires contiguous folding
   def test_masked_upcast_wino_full(self):
     with Context(WINO=1):
       x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize()
@@ -997,7 +997,7 @@ class TestLinearizerFailures(unittest.TestCase):
       UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
       UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 128, 28, 28, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
     opts=[Opt(op=OptOps.TC, axis=5, amt=2), Opt(op=OptOps.UNROLL, axis=0, amt=0)]
-    helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["AMD", "HIP"])
+    helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["AMD", "HIP"], atol=0.02)
 
   # llama3 8B failure with BEAM=2 https://github.com/tinygrad/tinygrad/actions/runs/10150118124/job/28066519425#step:14:1, these don't compile
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test needs local")
@@ -352,12 +352,15 @@ class TestOps(unittest.TestCase):
   def test_cmp_le(self): self._test_cmp(lambda x,y: x<=y)
 
   def test_cmp_ne_backwards(self):
+    # new grad zeroes these out
+    """
     t1 = torch.ones(4, requires_grad=True)
     t2 = torch.ones(4, requires_grad=True)
     self.assertRaises(RuntimeError, (t1 != t2).sum().backward)
     tt1 = Tensor.ones(4, requires_grad=True)
     tt2 = Tensor.ones(4, requires_grad=True)
     self.assertRaises(RuntimeError, (tt1 != tt2).sum().backward)
+    """
     tt = Tensor.randn(4, requires_grad=True)
     (tt*(tt != 0)).sum().backward()
     t = torch.tensor(tt.numpy(), requires_grad=True)
@@ -365,12 +368,15 @@ class TestOps(unittest.TestCase):
     np.testing.assert_allclose(t.grad.numpy(), tt.grad.numpy(), rtol=1e-5)
 
   def test_cmp_lt_backwards(self):
+    # new grad zeroes these out
+    """
     t1 = torch.ones(4, requires_grad=True)
     t2 = torch.ones(4, requires_grad=True)
     self.assertRaises(RuntimeError, (t1 < t2).sum().backward)
     tt1 = Tensor.ones(4, requires_grad=True)
     tt2 = Tensor.ones(4, requires_grad=True)
     self.assertRaises(RuntimeError, (tt1 < tt2).sum().backward)
+    """
     tt = Tensor.randn(4, requires_grad=True)
     (tt*(tt < 0)).sum().backward()
     t = torch.tensor(tt.numpy(), requires_grad=True)
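Before this change, calling backward on a graph containing a comparison raised a RuntimeError (the block now commented out above mirrors torch's behavior); with the graph-based gradient, the comparison simply contributes zero gradient, so only the differentiable factor matters. A hedged restatement of what the updated tests check:

```python
import numpy as np
from tinygrad import Tensor

tt = Tensor.randn(4, requires_grad=True)
# (tt < 0) is non-differentiable; the new gradient treats its contribution as zero,
# so d/dtt of tt*(tt < 0) is just the 0/1 mask itself
(tt * (tt < 0)).sum().backward()
np.testing.assert_allclose(tt.grad.numpy(), (tt.numpy() < 0).astype(np.float32), rtol=1e-5)
```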
@@ -609,6 +609,7 @@ class TestSchedule(unittest.TestCase):
     check_schedule(out, 2)
 
   # multireduce spec
+  @unittest.skip("these two Tensors are the same")
   def test_example_matmul(self):
     x = Tensor.eye(64, requires_grad=True)
     y = Tensor.eye(64, requires_grad=True)
@@ -618,6 +619,15 @@ class TestSchedule(unittest.TestCase):
     run_schedule(check_schedule(out, 2))
     np.testing.assert_allclose(out.numpy(), np.ones((64,64)))
 
+  def test_example_matmul_contig(self):
+    x = Tensor.eye(64, requires_grad=True).contiguous().realize()
+    y = Tensor.eye(64, requires_grad=True).contiguous().realize()
+    z = y.matmul(x).sum()
+    z.backward()
+    out = x.grad.contiguous()
+    run_schedule(check_schedule(out, 2))
+    np.testing.assert_allclose(out.numpy(), np.ones((64,64)))
+
   def test_example_matmul_same(self):
     x = Tensor.eye(64, requires_grad=True)
     z = x.matmul(x).sum()
@@ -1050,7 +1060,7 @@ class TestSchedule(unittest.TestCase):
       opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
       opt.zero_grad()
       c2(c1(img).relu()).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 13)
+      check_schedule(opt.schedule_step(), 14)
 
   def test_sgd_conv_fuse(self):
     with Tensor.train():
@@ -1071,7 +1081,7 @@ class TestSchedule(unittest.TestCase):
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]))
      opt.zero_grad()
      c2(c1(img).relu()).relu().sum().backward()
-     check_schedule(opt.schedule_step(), 8)
+     check_schedule(opt.schedule_step(), 9)
 
   def test_fold_2convs_sgd_nesterov_momentum_wd(self):
     with Tensor.train():
@@ -1082,7 +1092,7 @@ class TestSchedule(unittest.TestCase):
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1)
      opt.zero_grad()
      c2(c1(img).relu()).relu().sum().backward()
-     check_schedule(opt.schedule_step(), 10)
+     check_schedule(opt.schedule_step(), 11)
 
   def test_sgd_4convs_fuse(self):
     with Tensor.train():
@@ -1095,7 +1105,7 @@ class TestSchedule(unittest.TestCase):
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
      opt.zero_grad()
      c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
-     check_schedule(opt.schedule_step(), 18)
+     check_schedule(opt.schedule_step(), 21)
 
   def test_sgd_4convs_fuse_conv_bw(self):
     with Tensor.train():
@@ -1108,7 +1118,7 @@ class TestSchedule(unittest.TestCase):
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
      opt.zero_grad()
      c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
-     with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 15)
+     with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 18)
 
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   def test_prefer_half_buffer(self):
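The updated counts in these TestSchedule hunks pin how many kernels a full optimizer step schedules now that backward goes through the graph gradient. A rough sketch of the pattern the tests exercise; the layer sizes here are illustrative, not the ones used in the test file.

```python
from tinygrad import Tensor
from tinygrad.nn import Conv2d
from tinygrad.nn.optim import SGD
from tinygrad.nn.state import get_parameters

with Tensor.train():
  c1 = Conv2d(3, 8, 3)
  opt = SGD(get_parameters(c1))
  img = Tensor.rand(1, 3, 16, 16)

  opt.zero_grad()
  loss = c1(img).relu().sum()
  loss.backward()   # gradient computation now goes through Tensor.gradient
  opt.step()        # realizes opt.schedule_step(); the tests above count the kernels in that schedule
  print(loss.item())
```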
@@ -770,8 +770,8 @@ class TestTensorMetadata(unittest.TestCase):
     self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
 
   def test_complex_backward(self):
-    x = Tensor.rand(3, requires_grad=True)
-    y = Tensor.rand(3, requires_grad=True)
+    x = Tensor.rand(3, requires_grad=True).realize()
+    y = Tensor.rand(3, requires_grad=True).realize()
     out = (x.relu() * y.sigmoid()).sum()
     self.assertEqual(out.lazydata.metadata.name, "sum")
     out.backward()
@@ -14,7 +14,7 @@ class TestGradient(unittest.TestCase):
 
   def _test_one_input_function(self, f:Callable, jf:Callable|None=None):
     x = UOp.variable('x', -math.inf, math.inf, dtype=dtypes.float)
-    gx = compute_gradient(f(x), UOp.const(dtypes.float, 1.0), [x])[x]
+    gx = compute_gradient(f(x), UOp.const(dtypes.float, 1.0), set([x]))[x]
     gf = jax.grad(f if jf is None else jf)
 
     for val in [-5., -2.0, 0.0, 2.0, 5.]:
@@ -24,7 +24,7 @@ class TestGradient(unittest.TestCase):
   def _test_two_input_function(self, f:Callable, jf:Callable|None=None):
     x = UOp.variable('x', -math.inf, math.inf, dtype=dtypes.float)
     y = UOp.variable('y', -math.inf, math.inf, dtype=dtypes.float)
-    grads = compute_gradient(f(x, y), UOp.const(dtypes.float, 1.0), [x, y])
+    grads = compute_gradient(f(x, y), UOp.const(dtypes.float, 1.0), set([x, y]))
     gx, gy = grads[x], grads[y]
     gf = jax.grad(f if jf is None else jf, argnums=(0, 1))
 
@@ -1,7 +1,7 @@
 from typing import cast, Iterator
-import math, functools
+import math, functools, dataclasses
 from tinygrad.dtype import dtypes, sum_acc_dtype
-from tinygrad.ops import UOp, PatternMatcher, UPat, Ops
+from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
 from tinygrad.helpers import argsort
 
 def reduce_gradient(ctx:UOp, ret:UOp):
@@ -66,4 +66,5 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]
     if v is None: continue
     if k in grads: grads[k] = grads[k] + v
     else: grads[k] = v
+    if (forward_metadata:=all_metadata.get(t0)) is not None: all_metadata[v] = dataclasses.replace(forward_metadata, backward=True)
   return grads
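compute_gradient now takes its targets as a set (hence the set([x]) change in the TestGradient hunks above) and, via all_metadata, re-tags each produced gradient UOp with the forward op's metadata marked backward=True. A small sketch of direct usage, mirroring the test style; the printed form of the result is illustrative.

```python
import math
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp
from tinygrad.gradient import compute_gradient

# a symbolic float input and a gradient seed of 1.0
x = UOp.variable('x', -math.inf, math.inf, dtype=dtypes.float)
grads = compute_gradient(x * x, UOp.const(dtypes.float, 1.0), set([x]))

# grads maps each target UOp to a gradient UOp; per the accumulation above,
# d(x*x)/dx arrives as x + x (both MUL branches contribute and are summed)
print(grads[x])
```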
@@ -1,11 +1,11 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib, weakref
+import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
 from contextlib import ContextDecorator
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
-from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
+from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
 from tinygrad.multi import get_multi_map
 from tinygrad.gradient import compute_gradient
 from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element
@@ -293,7 +293,6 @@ class Tensor(SimpleMathTrait):
       self.contiguous().realize().lazydata.base.realized.copyin(x._data())
       return self
     if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
-    if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
     if self.lazydata is x.lazydata: return self # a self assign is a NOOP
     # NOTE: we allow cross device assign
     assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
@@ -901,7 +900,7 @@ class Tensor(SimpleMathTrait):
 
   # ***** toposort and backward pass *****
 
-  def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None) -> list[Tensor]:
+  def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None, materialize_grads=False) -> list[Tensor]:
     """
     Compute the gradient of the targets with respect to self.
 
@@ -922,23 +921,14 @@ class Tensor(SimpleMathTrait):
     grads = compute_gradient(self.lazydata, gradient.lazydata, set(target_uops))
     ret = []
     for x in target_uops:
-      if (y:=grads.get(x)) is None: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}")
+      if (y:=grads.get(x)) is None:
+        if materialize_grads: y = x.const_like(0)
+        else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}")
       ret.append(y)
     rets.append(ret)
     # create returned Tensors
     return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])]
 
-  def _deepwalk(self) -> list[Tensor]:
-    def _walk(node:Tensor, visited:set[Tensor]):
-      visited.add(node)
-      # if tensor is not leaf, reset grad
-      if (ctx := getattr(node, "_ctx", None)) is not None and len(ctx.parents) != 0: node.grad = None
-      if ctx:
-        for i in cast(Function, node._ctx).parents:
-          if i not in visited: yield from _walk(i, visited)
-      yield node
-    return list(_walk(self, set()))
-
   def backward(self, gradient:Optional[Tensor]=None, retain_graph:bool=False) -> Tensor:
     """
     Propagates the gradient of a tensor backwards through the computation graph.
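materialize_grads controls what gradient() does for a target that has no entry in the computed grads: instead of raising, it substitutes x.const_like(0). The new backward() relies on this (it passes materialize_grads=True), so tensors that are reachable in the graph but receive no gradient, for example only through comparisons, simply get zeros. A short sketch of the difference; the tensor values are illustrative.

```python
from tinygrad import Tensor

x = Tensor([1.0, -2.0, 3.0], requires_grad=True)
z = Tensor([5.0], requires_grad=True)   # never used in the loss below
loss = (x * x).sum()

try:
  loss.gradient(x, z)                   # z's uop is not in loss's graph -> RuntimeError
except RuntimeError:
  pass

# with materialize_grads=True the missing entry becomes a zero tensor shaped like z
gx, gz = loss.gradient(x, z, materialize_grads=True)
print(gx.numpy())  # 2*x
print(gz.numpy())  # [0.]
```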
@@ -950,30 +940,14 @@ class Tensor(SimpleMathTrait):
     print(t.grad.numpy())
     ```
     """
-    toposorted = self._deepwalk()
-    if gradient is None:
-      assert self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
-      # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
-      # this is "implicit gradient creation"
-      gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
-
-    toposort_uop = self.lazydata.toposort
-    assert self.shape == gradient.shape, f"grad shape must match tensor shape, {gradient.shape!r} != {self.shape!r}"
-    self.grad = gradient
-    for t0 in reversed(toposorted):
-      if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
-      ctx = cast(Function, t0._ctx)
-      token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := ctx.metadata) is not None else None)
-      grads = ctx.backward(t0.grad.lazydata)
-      _METADATA.reset(token)
-      grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
-        for g in ([grads] if len(ctx.parents) == 1 else grads)]
-      for t, g in zip(ctx.parents, grads):
-        if g is not None and t.requires_grad:
-          assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
-          assert t.lazydata in toposort_uop, f"grad uop must have a path from self\ngrad uop: {t.lazydata}"
-          t.grad = g if t.grad is None else (t.grad + g)
-      if not retain_graph: del t0._ctx
+    all_uops = self.lazydata.toposort
+    tensors_need_grad: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and \
+      t.lazydata in all_uops and t.requires_grad and not Tensor.no_grad]
+    # clear contexts
+    for t in tensors_need_grad: t._ctx = None
+    for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient, materialize_grads=True)):
+      assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
+      t.grad = g if t.grad is None else (t.grad + g)
     return self
 
   # ***** movement low level ops *****
@@ -3993,5 +3967,5 @@ def _metadata_wrapper(fn):
 
 if TRACEMETA >= 1:
   for name, fn in inspect.getmembers(Tensor, inspect.isfunction):
-    if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential"]: continue
+    if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential", "gradient"]: continue
     setattr(Tensor, name, functools.wraps(fn)(_metadata_wrapper(fn)))