functions for llama trainer (#15045)

* functions for llama trainer

* function there

* axis match

* fix multi

* lil cleaner

* there's a bug with HK_FLASH_ATTENTION

* training functions

* for commit
George Hotz
2026-02-28 12:15:18 +08:00
committed by GitHub
parent 9b4ba3f838
commit bb84e389cf
6 changed files with 127 additions and 10 deletions

View File

@@ -5,6 +5,7 @@ export DEV=${DEV:-AMD}
export EMULATE="AMD_CDNA4"
export CHECK_OOB=0
export REWRITE_STACK_LIMIT=5000000 HCQDEV_WAIT_TIMEOUT_MS=240000
export DEVICE_IN_FUNCTION_BUG=1
export DEBUG=${DEBUG:-0}
export HK_FLASH_ATTENTION=${HK_FLASH_ATTENTION:-1}

View File

@@ -129,6 +129,31 @@ class TestFunction(unittest.TestCase):
a = Tensor([1., 2., 3.])
np.testing.assert_allclose(g(f(a)).numpy(), [110., 440., 990.])
def test_nested_calls_backward(self):
w = Tensor([[1., 2.], [3., 4.]]).contiguous().realize()
@function
def inner(x:Tensor) -> Tensor: return x + w
@function
def outer(a:Tensor, b:Tensor) -> Tensor: return inner(a.reshape(1,2) + b.reshape(1,2))
a = Tensor([1., 2.], requires_grad=True)
b = Tensor([3., 4.], requires_grad=True)
outer(a, b).sum().backward()
np.testing.assert_allclose(a.grad.numpy(), [2., 2.])
np.testing.assert_allclose(b.grad.numpy(), [2., 2.])
def test_unused_param_backward(self):
@function
def f(a:Tensor, b:Tensor, c:Tensor) -> Tensor: return a + c # b is unused
a = Tensor([1., 2., 3.], requires_grad=True)
b = Tensor([4., 5., 6.], requires_grad=True)
c = Tensor([7., 8., 9.], requires_grad=True)
f(a, b, c).sum().backward()
np.testing.assert_allclose(a.grad.numpy(), [1., 1., 1.])
np.testing.assert_allclose(b.grad.numpy(), [0., 0., 0.])
np.testing.assert_allclose(c.grad.numpy(), [1., 1., 1.])
def test_name(self):
@function
def f(a:Tensor) -> Tensor: return a + 1
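
For test_nested_calls_backward above, the expected [2., 2.] gradients follow from broadcasting: (a + b) is reshaped to (1,2) and added to the (2,2) weight, so each element of a and b feeds two output elements. A quick check of that arithmetic (plain numpy, no tinygrad):

import numpy as np
a, b = np.array([1., 2.]), np.array([3., 4.])
w = np.array([[1., 2.], [3., 4.]])
out = (a + b).reshape(1, 2) + w      # broadcasts (1,2) against (2,2) -> (2,2)
# sum(out) counts each element of a (and b) once per row of w, so d sum / da = [2., 2.]
assert out.shape == (2, 2)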
@@ -230,5 +255,77 @@ class TestFunctionMulti(unittest.TestCase):
f(x).sum().backward()
np.testing.assert_allclose(w.grad.numpy(), [4., 5., 6., 7.])
def test_call_axis(self):
@function
def f(x:Tensor, w:Tensor) -> Tensor: return x @ w
x = Tensor([[1.,0.],[0.,1.],[1.,1.],[0.,0.]]).shard(self.devices_2, axis=0)
w = Tensor([[1.,2.],[3.,4.]]).shard(self.devices_2, axis=None)
result = f(x, w)
# CALL output should inherit axis=0 from the sharded input
self.assertEqual(result.uop.axis, 0)
# reduce on the sharded axis should remove it
self.assertIsNone(result.sum().uop.axis)
def test_call_axis_shard_inside(self):
@function
def f(x:Tensor, w:Tensor) -> Tensor:
return x.shard(self.devices_2, axis=0) @ w.shard(self.devices_2, axis=None)
x = Tensor([[1.,0.],[0.,1.],[1.,1.],[0.,0.]])
w = Tensor([[1.,2.],[3.,4.]])
result = f(x, w)
self.assertEqual(result.uop.axis, 0)
np.testing.assert_allclose(result.numpy(), x.numpy() @ w.numpy())
def test_data_parallel_backward(self):
@function
def f(x:Tensor, w:Tensor) -> Tensor: return x @ w
x = Tensor([[1.,0.],[0.,1.],[1.,1.],[0.,0.]], requires_grad=True).shard(self.devices_2, axis=0)
w = Tensor([[1.,2.],[3.,4.]], requires_grad=True).shard(self.devices_2, axis=None)
w.realize()
f(x, w).sum().backward()
# L = sum(x @ w), so dL/d(x@w) = ones(4,2) and dL/dx = ones(4,2) @ w^T
np.testing.assert_allclose(x.grad.numpy(), np.ones((4,2)) @ np.array([[1,3],[2,4]]))
def test_data_parallel_backward_4(self):
devices_4 = tuple(f"CPU:{i}" for i in range(4))
@function
def f(x:Tensor, w:Tensor) -> Tensor: return x @ w
x = Tensor(np.arange(16).reshape(8,2).astype(np.float32), requires_grad=True).shard(devices_4, axis=0)
w = Tensor([[1.,2.],[3.,4.]], requires_grad=True).shard(devices_4, axis=None)
w.realize()
f(x, w).sum().backward()
np.testing.assert_allclose(x.grad.numpy(), np.ones((8,2)) @ np.array([[1,3],[2,4]]))
def test_data_parallel_backward_implicit(self):
devices_4 = tuple(f"CPU:{i}" for i in range(4))
w = Tensor([[1.,2.],[3.,4.]], requires_grad=True).shard(devices_4, axis=None)
w.realize()
@function
def f(x:Tensor) -> Tensor: return x @ w
x = Tensor(np.arange(16).reshape(8,2).astype(np.float32), requires_grad=True).shard(devices_4, axis=0)
f(x).sum().backward()
np.testing.assert_allclose(x.grad.numpy(), np.ones((8,2)) @ np.array([[1,3],[2,4]]))
def test_data_parallel_backward_twice(self):
devices_4 = tuple(f"CPU:{i}" for i in range(4))
w = Tensor([[1.,2.],[3.,4.]], requires_grad=True).shard(devices_4, axis=None)
w.realize()
# pre-init grads like the training loop does
w.grad = w.zeros_like().contiguous().realize()
@function
def f(x:Tensor) -> Tensor: return x @ w
expected = np.ones((8,2)) @ np.array([[1,3],[2,4]])
for _ in range(2):
x = Tensor(np.arange(16).reshape(8,2).astype(np.float32), requires_grad=True).shard(devices_4, axis=0)
f(x).sum().backward()
np.testing.assert_allclose(x.grad.numpy(), expected)
if __name__ == '__main__':
unittest.main()
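
The data-parallel tests above all expect dL/dx = ones @ w^T; sharding x across devices should not change the math. A quick standalone check of the value the assertions use (plain numpy, no sharding):

import numpy as np
x = np.arange(16).reshape(8, 2).astype(np.float32)
w = np.array([[1., 2.], [3., 4.]], dtype=np.float32)
# L = sum(x @ w): dL/d(x@w) is all ones, so dL/dx = ones @ w^T (and dL/dw = x^T @ ones)
dL_dx = np.ones_like(x @ w) @ w.T
assert np.allclose(dL_dx, np.ones((8, 2)) @ np.array([[1, 3], [2, 4]]))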

View File

@@ -17,12 +17,13 @@ def call_gradient(ctx:UOp, k:UOp) -> tuple[UOp|None, ...]:
if k.arg.grad_fxn is not None: return (None,) + k.arg.grad_fxn(ctx, k)
# auto-differentiate the function
fxn, args = k.src[0], k.src[1:]
- params = sorted([x for x in fxn.toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
- grads = compute_gradient(fxn, ctx.param_like(len(args)), set(params))
+ params = {x.arg:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
+ grads = compute_gradient(fxn, ctx.param_like(len(args)), set(params.values()))
ret: list[UOp|None] = [None]
- for i,p in enumerate(params):
- if p in grads:
+ for i in range(len(args)):
+ if (p:=params.get(i, None)) is not None and p in grads:
# TODO: compact the args and remove unused ones
assert not grads[p].op_in_backward_slice_with_self(Ops.BUFFER), "BUG: BUFFER in backward slice of grad"
ret.append(grads[p].call(*args, ctx, name=(k.arg.name or "")+f"_backward_{i}"))
else:
ret.append(None)
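
The rewritten loop above keys PARAMs by their arg slot (x.arg) instead of position in a sorted list, so a call's gradients stay aligned with its args even when some args never appear in the body. A minimal sketch of that alignment with plain dicts standing in for UOps (names here are illustrative):

# toy stand-ins: params maps slot index -> param, grads maps param -> gradient expression
params = {0: "param_a", 2: "param_c"}                 # slot 1 (unused arg b) has no PARAM
grads = {"param_a": "dL/da", "param_c": "dL/dc"}
args = ("a", "b", "c")

ret = [None]                                          # first entry mirrors the function body slot
for i in range(len(args)):
    p = params.get(i)
    ret.append(grads[p] if p is not None and p in grads else None)
assert ret == [None, "dL/da", None, "dL/dc"]          # unused arg b correctly gets None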

View File

@@ -164,10 +164,18 @@ def passthrough_multi(root:UOp, multi:UOp):
return UOp(root.op, root.dtype, (multi.src[0],)+tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src[1:]), root.arg).multi(multi.axis)
def rewrite_into_call(call:UOp):
- return call.replace(src=(graph_rewrite(call.src[0], multi_pm, name="subcall"),)+call.src[1:]) if should_resolve_call(call) else None
+ if not should_resolve_call(call): return None
+ new_body = graph_rewrite(call.src[0], multi_pm, name="subcall")
+ new_args = tuple(a.src[0] if a.op is Ops.MULTI else a for a in call.src[1:])
+ return call.replace(src=(new_body,)+new_args)
def param_to_multi(p:UOp):
if p.axis is None: return None
return UOp.param(p.arg, p.dtype, p.shard_shape, p._device).multi(p.axis)
# NOTE: this is the same pattern as Ops.UNROLL
multi_pm = PatternMatcher([
(UPat(Ops.PARAM, name="p"), param_to_multi),
(UPat(GroupOp.ALU, name="root", custom_early_reject=set([Ops.MULTI])), alu_multi),
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reduce_multi),
(UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), UPat()), name="root"), reshape_multi),
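
param_to_multi above rebuilds a PARAM that carries an axis with its per-device shard shape before wrapping it in MULTI. A minimal sketch of that shape split, assuming the axis is divided evenly across devices (shard_shape here is a toy helper, not the UOp property):

def shard_shape(shape, axis, n_devices):
    # per-device view: the sharded axis is split evenly, other axes are untouched
    return tuple(s // n_devices if i == axis else s for i, s in enumerate(shape))

assert shard_shape((8, 2), 0, 4) == (2, 2)   # x in the 4-device tests: (8,2) sharded on axis 0
assert shard_shape((4, 2), 0, 2) == (2, 2)   # x in the 2-device tests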

View File

@@ -90,6 +90,7 @@ def resolve_call(c:UOp, allow_param_mismatch=True) -> UOp|None:
dict_map = {x:args[x.arg] for x in params}
for i, (p, a) in enumerate(dict_map.items()):
if p.axis != a.axis: raise TypeError(f"arg {i} axis mismatch: expected {p.axis}, got {a.axis}")
if p.max_shape != a.max_shape: raise TypeError(f"arg {i} shape mismatch: expected {p.shape}, got {a.shape}")
if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}")
return c.src[0].substitute(dict_map, walk=True)
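
The hunk above validates a resolved call's args against the traced PARAMs (axis, then shape, then dtype), so a function traced with sharded PARAMs rejects args sharded differently. A minimal sketch of the same per-arg validation with toy specs in place of UOps:

from collections import namedtuple

Spec = namedtuple("Spec", ["axis", "shape", "dtype"])

def check_args(params, args):
    # mirror resolve_call's per-arg checks: axis first, then shape, then dtype
    for i, (p, a) in enumerate(zip(params, args)):
        if p.axis != a.axis: raise TypeError(f"arg {i} axis mismatch: expected {p.axis}, got {a.axis}")
        if p.shape != a.shape: raise TypeError(f"arg {i} shape mismatch: expected {p.shape}, got {a.shape}")
        if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}")

params = [Spec(axis=0, shape=(4, 2), dtype="float32")]
try:
    check_args(params, [Spec(axis=None, shape=(4, 2), dtype="float32")])
except TypeError as e:
    print(e)   # arg 0 axis mismatch: expected 0, got None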

View File

@@ -163,7 +163,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# Check self first, then iterate backward_slice (avoids creating intermediate dict)
return self.op in ops or any(x.op in ops for x in self.backward_slice)
- def toposort(self, gate:Callable|None=None) -> dict[UOp, None]:
+ def toposort(self, gate:Callable|None=None, enter_calls=True) -> dict[UOp, None]:
cache: dict[UOp, None] = {}
stack: list[tuple[UOp, bool]] = [(self, False)] # each stack entry is (node, visited_flag)
while stack:
@@ -172,7 +172,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if not visited:
if gate is None or gate(node):
stack.append((node, True)) # push node back on stack to process after its srcs
- for s in reversed(node.src): stack.append((s, False)) # push srcs on the stack
+ for s in reversed(node.src if enter_calls or node.op is not Ops.CALL else node.src[1:]):
+ stack.append((s, False)) # push srcs on the stack
else: cache[node] = None # second time i'm seeing this node, add it to returned toposort
return cache
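
The new enter_calls flag lets toposort stop at CALL boundaries: when it is False, a CALL node contributes only its argument sources (src[1:]), not the traced function body in src[0]. A minimal sketch of that gated DFS with a toy node class (not tinygrad's UOp):

from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
    op: str
    src: tuple = ()

def toposort(root, enter_calls=True):
    # iterative post-order DFS: sources land in the result before the node itself
    cache, stack = {}, [(root, False)]
    while stack:
        node, visited = stack.pop()
        if not visited:
            stack.append((node, True))
            # with enter_calls=False, skip a CALL's body (src[0]) and walk only its args
            srcs = node.src if enter_calls or node.op != "CALL" else node.src[1:]
            for s in reversed(srcs): stack.append((s, False))
        else: cache[node] = None
    return [n.op for n in cache]

call = Node("CALL", (Node("BODY", (Node("PARAM"),)), Node("ARG0"), Node("ARG1")))
assert toposort(call, enter_calls=False) == ["ARG0", "ARG1", "CALL"]
assert "PARAM" in toposort(call, enter_calls=True)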
@@ -253,6 +254,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
case Ops.RESHAPE:
if self.src[0]._shape is None: return self.marg
# MULTI marker (axis info in PARAM sources) has no shape
case Ops.MULTI if len(self.src) == 0: return None
# movement ops change the shape
# NOTE: ssimplify is required because the shape needs to be canonical for broadcasting and same shape checking
if self.op in GroupOp.Movement.union({Ops.MULTI, Ops.REDUCE_AXIS, Ops.WMMA}):
@@ -514,6 +518,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# COPY removes axis. TODO: add more tests for this, and consider MSELECT/MSTACK
if self.op is Ops.COPY: return None
if self.op is Ops.MULTI: return self.arg
# PARAM: axis is stored as a MULTI source
if self.op is Ops.PARAM:
for s in self.src:
if s.op is Ops.MULTI: return s.arg
return None
# NOTE: they all have to share an axis, we always choose [-1]
if self.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in self.src if x.axis is not None])) else None
if len(self.src) == 0: return None
@@ -867,9 +876,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def param_like(self, slot:int):
if self.op is Ops.BIND:
return UOp.param(slot, self.dtype, self._shape, self._device, self._min_max, self.src[0].arg[0])
- if self.axis is not None:
- return UOp.param(slot, self.dtype, self.shard_shape, self._device).multi(self.axis)
- return UOp.param(slot, self.dtype, self._shape, self._device)
+ p = UOp.param(slot, self.dtype, self._shape, self._device)
+ if self.axis is not None: p = p.replace(src=p.src + (UOp(Ops.MULTI, arg=self.axis),))
+ return p
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(), name:str|None=None) -> UOp:
# TODO: reenable this after ENCDEC is fixed
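
The last hunks move a sharded tensor's axis out of the PARAM itself: param_like now builds a full-shape PARAM and attaches a bare Ops.MULTI marker carrying the axis as an extra source, and the axis property reads it back from there (the marker itself has no shape). A minimal sketch of that round trip with toy objects in place of the real UOp constructors:

from dataclasses import dataclass

@dataclass(frozen=True)
class ToyOp:
    op: str
    arg: object = None
    src: tuple = ()

    @property
    def axis(self):
        # mirrors the new PARAM handling: the axis lives on a MULTI marker in src
        if self.op == "PARAM":
            return next((s.arg for s in self.src if s.op == "MULTI"), None)
        return None

def param_like(slot, axis=None):
    p = ToyOp("PARAM", arg=slot)
    # attach the axis as a source marker instead of pre-sharding the param
    if axis is not None: p = ToyOp("PARAM", arg=slot, src=p.src + (ToyOp("MULTI", arg=axis),))
    return p

assert param_like(0, axis=0).axis == 0     # sharded input: marker carries axis 0
assert param_like(1).axis is None          # replicated input: no marker, no axis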