mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
functions for llama trainer (#15045)
* functions for llama trainer * function there * axis match * fix multi * lil cleaner * there's a bug with HK_FLASH_ATTENTION * training functions * for commit
This commit is contained in:
@@ -5,6 +5,7 @@ export DEV=${DEV:-AMD}
|
||||
export EMULATE="AMD_CDNA4"
|
||||
export CHECK_OOB=0
|
||||
export REWRITE_STACK_LIMIT=5000000 HCQDEV_WAIT_TIMEOUT_MS=240000
|
||||
export DEVICE_IN_FUNCTION_BUG=1
|
||||
|
||||
export DEBUG=${DEBUG:-0}
|
||||
export HK_FLASH_ATTENTION=${HK_FLASH_ATTENTION:-1}
|
||||
|
||||
@@ -129,6 +129,31 @@ class TestFunction(unittest.TestCase):
|
||||
a = Tensor([1., 2., 3.])
|
||||
np.testing.assert_allclose(g(f(a)).numpy(), [110., 440., 990.])
|
||||
|
||||
def test_nested_calls_backward(self):
|
||||
w = Tensor([[1., 2.], [3., 4.]]).contiguous().realize()
|
||||
@function
|
||||
def inner(x:Tensor) -> Tensor: return x + w
|
||||
@function
|
||||
def outer(a:Tensor, b:Tensor) -> Tensor: return inner(a.reshape(1,2) + b.reshape(1,2))
|
||||
|
||||
a = Tensor([1., 2.], requires_grad=True)
|
||||
b = Tensor([3., 4.], requires_grad=True)
|
||||
outer(a, b).sum().backward()
|
||||
np.testing.assert_allclose(a.grad.numpy(), [2., 2.])
|
||||
np.testing.assert_allclose(b.grad.numpy(), [2., 2.])
|
||||
|
||||
def test_unused_param_backward(self):
|
||||
@function
|
||||
def f(a:Tensor, b:Tensor, c:Tensor) -> Tensor: return a + c # b is unused
|
||||
|
||||
a = Tensor([1., 2., 3.], requires_grad=True)
|
||||
b = Tensor([4., 5., 6.], requires_grad=True)
|
||||
c = Tensor([7., 8., 9.], requires_grad=True)
|
||||
f(a, b, c).sum().backward()
|
||||
np.testing.assert_allclose(a.grad.numpy(), [1., 1., 1.])
|
||||
np.testing.assert_allclose(b.grad.numpy(), [0., 0., 0.])
|
||||
np.testing.assert_allclose(c.grad.numpy(), [1., 1., 1.])
|
||||
|
||||
def test_name(self):
|
||||
@function
|
||||
def f(a:Tensor) -> Tensor: return a + 1
|
||||
@@ -230,5 +255,77 @@ class TestFunctionMulti(unittest.TestCase):
|
||||
f(x).sum().backward()
|
||||
np.testing.assert_allclose(w.grad.numpy(), [4., 5., 6., 7.])
|
||||
|
||||
def test_call_axis(self):
|
||||
@function
|
||||
def f(x:Tensor, w:Tensor) -> Tensor: return x @ w
|
||||
|
||||
x = Tensor([[1.,0.],[0.,1.],[1.,1.],[0.,0.]]).shard(self.devices_2, axis=0)
|
||||
w = Tensor([[1.,2.],[3.,4.]]).shard(self.devices_2, axis=None)
|
||||
result = f(x, w)
|
||||
# CALL output should inherit axis=0 from the sharded input
|
||||
self.assertEqual(result.uop.axis, 0)
|
||||
# reduce on the sharded axis should remove it
|
||||
self.assertIsNone(result.sum().uop.axis)
|
||||
|
||||
def test_call_axis_shard_inside(self):
|
||||
@function
|
||||
def f(x:Tensor, w:Tensor) -> Tensor:
|
||||
return x.shard(self.devices_2, axis=0) @ w.shard(self.devices_2, axis=None)
|
||||
|
||||
x = Tensor([[1.,0.],[0.,1.],[1.,1.],[0.,0.]])
|
||||
w = Tensor([[1.,2.],[3.,4.]])
|
||||
result = f(x, w)
|
||||
self.assertEqual(result.uop.axis, 0)
|
||||
np.testing.assert_allclose(result.numpy(), x.numpy() @ w.numpy())
|
||||
|
||||
def test_data_parallel_backward(self):
|
||||
@function
|
||||
def f(x:Tensor, w:Tensor) -> Tensor: return x @ w
|
||||
|
||||
x = Tensor([[1.,0.],[0.,1.],[1.,1.],[0.,0.]], requires_grad=True).shard(self.devices_2, axis=0)
|
||||
w = Tensor([[1.,2.],[3.,4.]], requires_grad=True).shard(self.devices_2, axis=None)
|
||||
w.realize()
|
||||
f(x, w).sum().backward()
|
||||
# d/dx = ones @ w^T = [[1,3],[1,3],[1,3],[1,3]], but sum so ones(4,2) @ w^T? no:
|
||||
# L = sum(x @ w), dL/dx = ones(4,2) @ w^T... actually dL/d(xw) = ones(4,2), dL/dx = ones(4,2) @ w^T
|
||||
np.testing.assert_allclose(x.grad.numpy(), np.ones((4,2)) @ np.array([[1,3],[2,4]]))
|
||||
|
||||
def test_data_parallel_backward_4(self):
|
||||
devices_4 = tuple(f"CPU:{i}" for i in range(4))
|
||||
@function
|
||||
def f(x:Tensor, w:Tensor) -> Tensor: return x @ w
|
||||
|
||||
x = Tensor(np.arange(16).reshape(8,2).astype(np.float32), requires_grad=True).shard(devices_4, axis=0)
|
||||
w = Tensor([[1.,2.],[3.,4.]], requires_grad=True).shard(devices_4, axis=None)
|
||||
w.realize()
|
||||
f(x, w).sum().backward()
|
||||
np.testing.assert_allclose(x.grad.numpy(), np.ones((8,2)) @ np.array([[1,3],[2,4]]))
|
||||
|
||||
def test_data_parallel_backward_implicit(self):
|
||||
devices_4 = tuple(f"CPU:{i}" for i in range(4))
|
||||
w = Tensor([[1.,2.],[3.,4.]], requires_grad=True).shard(devices_4, axis=None)
|
||||
w.realize()
|
||||
@function
|
||||
def f(x:Tensor) -> Tensor: return x @ w
|
||||
|
||||
x = Tensor(np.arange(16).reshape(8,2).astype(np.float32), requires_grad=True).shard(devices_4, axis=0)
|
||||
f(x).sum().backward()
|
||||
np.testing.assert_allclose(x.grad.numpy(), np.ones((8,2)) @ np.array([[1,3],[2,4]]))
|
||||
|
||||
def test_data_parallel_backward_twice(self):
|
||||
devices_4 = tuple(f"CPU:{i}" for i in range(4))
|
||||
w = Tensor([[1.,2.],[3.,4.]], requires_grad=True).shard(devices_4, axis=None)
|
||||
w.realize()
|
||||
# pre-init grads like the training loop does
|
||||
w.grad = w.zeros_like().contiguous().realize()
|
||||
@function
|
||||
def f(x:Tensor) -> Tensor: return x @ w
|
||||
|
||||
expected = np.ones((8,2)) @ np.array([[1,3],[2,4]])
|
||||
for _ in range(2):
|
||||
x = Tensor(np.arange(16).reshape(8,2).astype(np.float32), requires_grad=True).shard(devices_4, axis=0)
|
||||
f(x).sum().backward()
|
||||
np.testing.assert_allclose(x.grad.numpy(), expected)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -17,12 +17,13 @@ def call_gradient(ctx:UOp, k:UOp) -> tuple[UOp|None, ...]:
|
||||
if k.arg.grad_fxn is not None: return (None,) + k.arg.grad_fxn(ctx, k)
|
||||
# auto-differentiate the function
|
||||
fxn, args = k.src[0], k.src[1:]
|
||||
params = sorted([x for x in fxn.toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
|
||||
grads = compute_gradient(fxn, ctx.param_like(len(args)), set(params))
|
||||
params = {x.arg:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
|
||||
grads = compute_gradient(fxn, ctx.param_like(len(args)), set(params.values()))
|
||||
ret: list[UOp|None] = [None]
|
||||
for i,p in enumerate(params):
|
||||
if p in grads:
|
||||
for i in range(len(args)):
|
||||
if (p:=params.get(i, None)) is not None and p in grads:
|
||||
# TODO: compact the args and remove unused ones
|
||||
assert not grads[p].op_in_backward_slice_with_self(Ops.BUFFER), "BUG: BUFFER in backward slice of grad"
|
||||
ret.append(grads[p].call(*args, ctx, name=(k.arg.name or "")+f"_backward_{i}"))
|
||||
else:
|
||||
ret.append(None)
|
||||
|
||||
@@ -164,10 +164,18 @@ def passthrough_multi(root:UOp, multi:UOp):
|
||||
return UOp(root.op, root.dtype, (multi.src[0],)+tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src[1:]), root.arg).multi(multi.axis)
|
||||
|
||||
def rewrite_into_call(call:UOp):
|
||||
return call.replace(src=(graph_rewrite(call.src[0], multi_pm, name="subcall"),)+call.src[1:]) if should_resolve_call(call) else None
|
||||
if not should_resolve_call(call): return None
|
||||
new_body = graph_rewrite(call.src[0], multi_pm, name="subcall")
|
||||
new_args = tuple(a.src[0] if a.op is Ops.MULTI else a for a in call.src[1:])
|
||||
return call.replace(src=(new_body,)+new_args)
|
||||
|
||||
def param_to_multi(p:UOp):
|
||||
if p.axis is None: return None
|
||||
return UOp.param(p.arg, p.dtype, p.shard_shape, p._device).multi(p.axis)
|
||||
|
||||
# NOTE: this is the same pattern as Ops.UNROLL
|
||||
multi_pm = PatternMatcher([
|
||||
(UPat(Ops.PARAM, name="p"), param_to_multi),
|
||||
(UPat(GroupOp.ALU, name="root", custom_early_reject=set([Ops.MULTI])), alu_multi),
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reduce_multi),
|
||||
(UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), UPat()), name="root"), reshape_multi),
|
||||
|
||||
@@ -90,6 +90,7 @@ def resolve_call(c:UOp, allow_param_mismatch=True) -> UOp|None:
|
||||
|
||||
dict_map = {x:args[x.arg] for x in params}
|
||||
for i, (p, a) in enumerate(dict_map.items()):
|
||||
if p.axis != a.axis: raise TypeError(f"arg {i} axis mismatch: expected {p.axis}, got {a.axis}")
|
||||
if p.max_shape != a.max_shape: raise TypeError(f"arg {i} shape mismatch: expected {p.shape}, got {a.shape}")
|
||||
if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}")
|
||||
return c.src[0].substitute(dict_map, walk=True)
|
||||
|
||||
@@ -163,7 +163,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
# Check self first, then iterate backward_slice (avoids creating intermediate dict)
|
||||
return self.op in ops or any(x.op in ops for x in self.backward_slice)
|
||||
|
||||
def toposort(self, gate:Callable|None=None) -> dict[UOp, None]:
|
||||
def toposort(self, gate:Callable|None=None, enter_calls=True) -> dict[UOp, None]:
|
||||
cache: dict[UOp, None] = {}
|
||||
stack: list[tuple[UOp, bool]] = [(self, False)] # each stack entry is (node, visited_flag)
|
||||
while stack:
|
||||
@@ -172,7 +172,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
if not visited:
|
||||
if gate is None or gate(node):
|
||||
stack.append((node, True)) # push node back on stack to process after its srcs
|
||||
for s in reversed(node.src): stack.append((s, False)) # push srcs on the stack
|
||||
for s in reversed(node.src if enter_calls or node.op is not Ops.CALL else node.src[1:]):
|
||||
stack.append((s, False)) # push srcs on the stack
|
||||
else: cache[node] = None # second time i'm seeing this node, add it to returned toposort
|
||||
return cache
|
||||
|
||||
@@ -253,6 +254,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
case Ops.RESHAPE:
|
||||
if self.src[0]._shape is None: return self.marg
|
||||
|
||||
# MULTI marker (axis info in PARAM sources) has no shape
|
||||
case Ops.MULTI if len(self.src) == 0: return None
|
||||
|
||||
# movement ops change the shape
|
||||
# NOTE: ssimplify is required because the shape needs to be canonical for broadcasting and same shape checking
|
||||
if self.op in GroupOp.Movement.union({Ops.MULTI, Ops.REDUCE_AXIS, Ops.WMMA}):
|
||||
@@ -514,6 +518,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
# COPY removes axis. TODO: add more tests for this, and consider MSELECT/MSTACK
|
||||
if self.op is Ops.COPY: return None
|
||||
if self.op is Ops.MULTI: return self.arg
|
||||
# PARAM: axis is stored as a MULTI source
|
||||
if self.op is Ops.PARAM:
|
||||
for s in self.src:
|
||||
if s.op is Ops.MULTI: return s.arg
|
||||
return None
|
||||
# NOTE: they all have to share an axis, we always choose [-1]
|
||||
if self.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in self.src if x.axis is not None])) else None
|
||||
if len(self.src) == 0: return None
|
||||
@@ -867,9 +876,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
def param_like(self, slot:int):
|
||||
if self.op is Ops.BIND:
|
||||
return UOp.param(slot, self.dtype, self._shape, self._device, self._min_max, self.src[0].arg[0])
|
||||
if self.axis is not None:
|
||||
return UOp.param(slot, self.dtype, self.shard_shape, self._device).multi(self.axis)
|
||||
return UOp.param(slot, self.dtype, self._shape, self._device)
|
||||
p = UOp.param(slot, self.dtype, self._shape, self._device)
|
||||
if self.axis is not None: p = p.replace(src=p.src + (UOp(Ops.MULTI, arg=self.axis),))
|
||||
return p
|
||||
|
||||
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(), name:str|None=None) -> UOp:
|
||||
# TODO: reenable this after ENCDEC is fixed
|
||||
|
||||
Reference in New Issue
Block a user