diff --git a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh index c729d1b947..89c48e4d6d 100755 --- a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh +++ b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh @@ -5,6 +5,7 @@ export DEV=${DEV:-AMD} export EMULATE="AMD_CDNA4" export CHECK_OOB=0 export REWRITE_STACK_LIMIT=5000000 HCQDEV_WAIT_TIMEOUT_MS=240000 +export DEVICE_IN_FUNCTION_BUG=1 export DEBUG=${DEBUG:-0} export HK_FLASH_ATTENTION=${HK_FLASH_ATTENTION:-1} diff --git a/test/unit/test_function.py b/test/unit/test_function.py index 2a3fa6a1e2..e3cf7335de 100644 --- a/test/unit/test_function.py +++ b/test/unit/test_function.py @@ -129,6 +129,31 @@ class TestFunction(unittest.TestCase): a = Tensor([1., 2., 3.]) np.testing.assert_allclose(g(f(a)).numpy(), [110., 440., 990.]) + def test_nested_calls_backward(self): + w = Tensor([[1., 2.], [3., 4.]]).contiguous().realize() + @function + def inner(x:Tensor) -> Tensor: return x + w + @function + def outer(a:Tensor, b:Tensor) -> Tensor: return inner(a.reshape(1,2) + b.reshape(1,2)) + + a = Tensor([1., 2.], requires_grad=True) + b = Tensor([3., 4.], requires_grad=True) + outer(a, b).sum().backward() + np.testing.assert_allclose(a.grad.numpy(), [2., 2.]) + np.testing.assert_allclose(b.grad.numpy(), [2., 2.]) + + def test_unused_param_backward(self): + @function + def f(a:Tensor, b:Tensor, c:Tensor) -> Tensor: return a + c # b is unused + + a = Tensor([1., 2., 3.], requires_grad=True) + b = Tensor([4., 5., 6.], requires_grad=True) + c = Tensor([7., 8., 9.], requires_grad=True) + f(a, b, c).sum().backward() + np.testing.assert_allclose(a.grad.numpy(), [1., 1., 1.]) + np.testing.assert_allclose(b.grad.numpy(), [0., 0., 0.]) + np.testing.assert_allclose(c.grad.numpy(), [1., 1., 1.]) + def test_name(self): @function def f(a:Tensor) -> Tensor: return a + 1 @@ -230,5 +255,77 @@ class TestFunctionMulti(unittest.TestCase): f(x).sum().backward() np.testing.assert_allclose(w.grad.numpy(), [4., 5., 6., 7.]) + def test_call_axis(self): + @function + def f(x:Tensor, w:Tensor) -> Tensor: return x @ w + + x = Tensor([[1.,0.],[0.,1.],[1.,1.],[0.,0.]]).shard(self.devices_2, axis=0) + w = Tensor([[1.,2.],[3.,4.]]).shard(self.devices_2, axis=None) + result = f(x, w) + # CALL output should inherit axis=0 from the sharded input + self.assertEqual(result.uop.axis, 0) + # reduce on the sharded axis should remove it + self.assertIsNone(result.sum().uop.axis) + + def test_call_axis_shard_inside(self): + @function + def f(x:Tensor, w:Tensor) -> Tensor: + return x.shard(self.devices_2, axis=0) @ w.shard(self.devices_2, axis=None) + + x = Tensor([[1.,0.],[0.,1.],[1.,1.],[0.,0.]]) + w = Tensor([[1.,2.],[3.,4.]]) + result = f(x, w) + self.assertEqual(result.uop.axis, 0) + np.testing.assert_allclose(result.numpy(), x.numpy() @ w.numpy()) + + def test_data_parallel_backward(self): + @function + def f(x:Tensor, w:Tensor) -> Tensor: return x @ w + + x = Tensor([[1.,0.],[0.,1.],[1.,1.],[0.,0.]], requires_grad=True).shard(self.devices_2, axis=0) + w = Tensor([[1.,2.],[3.,4.]], requires_grad=True).shard(self.devices_2, axis=None) + w.realize() + f(x, w).sum().backward() + # d/dx = ones @ w^T = [[1,3],[1,3],[1,3],[1,3]], but sum so ones(4,2) @ w^T? no: + # L = sum(x @ w), dL/dx = ones(4,2) @ w^T... actually dL/d(xw) = ones(4,2), dL/dx = ones(4,2) @ w^T + np.testing.assert_allclose(x.grad.numpy(), np.ones((4,2)) @ np.array([[1,3],[2,4]])) + + def test_data_parallel_backward_4(self): + devices_4 = tuple(f"CPU:{i}" for i in range(4)) + @function + def f(x:Tensor, w:Tensor) -> Tensor: return x @ w + + x = Tensor(np.arange(16).reshape(8,2).astype(np.float32), requires_grad=True).shard(devices_4, axis=0) + w = Tensor([[1.,2.],[3.,4.]], requires_grad=True).shard(devices_4, axis=None) + w.realize() + f(x, w).sum().backward() + np.testing.assert_allclose(x.grad.numpy(), np.ones((8,2)) @ np.array([[1,3],[2,4]])) + + def test_data_parallel_backward_implicit(self): + devices_4 = tuple(f"CPU:{i}" for i in range(4)) + w = Tensor([[1.,2.],[3.,4.]], requires_grad=True).shard(devices_4, axis=None) + w.realize() + @function + def f(x:Tensor) -> Tensor: return x @ w + + x = Tensor(np.arange(16).reshape(8,2).astype(np.float32), requires_grad=True).shard(devices_4, axis=0) + f(x).sum().backward() + np.testing.assert_allclose(x.grad.numpy(), np.ones((8,2)) @ np.array([[1,3],[2,4]])) + + def test_data_parallel_backward_twice(self): + devices_4 = tuple(f"CPU:{i}" for i in range(4)) + w = Tensor([[1.,2.],[3.,4.]], requires_grad=True).shard(devices_4, axis=None) + w.realize() + # pre-init grads like the training loop does + w.grad = w.zeros_like().contiguous().realize() + @function + def f(x:Tensor) -> Tensor: return x @ w + + expected = np.ones((8,2)) @ np.array([[1,3],[2,4]]) + for _ in range(2): + x = Tensor(np.arange(16).reshape(8,2).astype(np.float32), requires_grad=True).shard(devices_4, axis=0) + f(x).sum().backward() + np.testing.assert_allclose(x.grad.numpy(), expected) + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py index 0a421d8bd6..d40b6140ec 100644 --- a/tinygrad/gradient.py +++ b/tinygrad/gradient.py @@ -17,12 +17,13 @@ def call_gradient(ctx:UOp, k:UOp) -> tuple[UOp|None, ...]: if k.arg.grad_fxn is not None: return (None,) + k.arg.grad_fxn(ctx, k) # auto-differentiate the function fxn, args = k.src[0], k.src[1:] - params = sorted([x for x in fxn.toposort() if x.op == Ops.PARAM], key=lambda x: x.arg) - grads = compute_gradient(fxn, ctx.param_like(len(args)), set(params)) + params = {x.arg:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM} + grads = compute_gradient(fxn, ctx.param_like(len(args)), set(params.values())) ret: list[UOp|None] = [None] - for i,p in enumerate(params): - if p in grads: + for i in range(len(args)): + if (p:=params.get(i, None)) is not None and p in grads: # TODO: compact the args and remove unused ones + assert not grads[p].op_in_backward_slice_with_self(Ops.BUFFER), "BUG: BUFFER in backward slice of grad" ret.append(grads[p].call(*args, ctx, name=(k.arg.name or "")+f"_backward_{i}")) else: ret.append(None) diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py index b92bc38dc8..73a5ff92ab 100644 --- a/tinygrad/schedule/multi.py +++ b/tinygrad/schedule/multi.py @@ -164,10 +164,18 @@ def passthrough_multi(root:UOp, multi:UOp): return UOp(root.op, root.dtype, (multi.src[0],)+tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src[1:]), root.arg).multi(multi.axis) def rewrite_into_call(call:UOp): - return call.replace(src=(graph_rewrite(call.src[0], multi_pm, name="subcall"),)+call.src[1:]) if should_resolve_call(call) else None + if not should_resolve_call(call): return None + new_body = graph_rewrite(call.src[0], multi_pm, name="subcall") + new_args = tuple(a.src[0] if a.op is Ops.MULTI else a for a in call.src[1:]) + return call.replace(src=(new_body,)+new_args) + +def param_to_multi(p:UOp): + if p.axis is None: return None + return UOp.param(p.arg, p.dtype, p.shard_shape, p._device).multi(p.axis) # NOTE: this is the same pattern as Ops.UNROLL multi_pm = PatternMatcher([ + (UPat(Ops.PARAM, name="p"), param_to_multi), (UPat(GroupOp.ALU, name="root", custom_early_reject=set([Ops.MULTI])), alu_multi), (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reduce_multi), (UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), UPat()), name="root"), reshape_multi), diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 2ddde4ecff..4cd892b6e6 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -90,6 +90,7 @@ def resolve_call(c:UOp, allow_param_mismatch=True) -> UOp|None: dict_map = {x:args[x.arg] for x in params} for i, (p, a) in enumerate(dict_map.items()): + if p.axis != a.axis: raise TypeError(f"arg {i} axis mismatch: expected {p.axis}, got {a.axis}") if p.max_shape != a.max_shape: raise TypeError(f"arg {i} shape mismatch: expected {p.shape}, got {a.shape}") if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}") return c.src[0].substitute(dict_map, walk=True) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 12adde4866..b101d83096 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -163,7 +163,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): # Check self first, then iterate backward_slice (avoids creating intermediate dict) return self.op in ops or any(x.op in ops for x in self.backward_slice) - def toposort(self, gate:Callable|None=None) -> dict[UOp, None]: + def toposort(self, gate:Callable|None=None, enter_calls=True) -> dict[UOp, None]: cache: dict[UOp, None] = {} stack: list[tuple[UOp, bool]] = [(self, False)] # each stack entry is (node, visited_flag) while stack: @@ -172,7 +172,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass): if not visited: if gate is None or gate(node): stack.append((node, True)) # push node back on stack to process after its srcs - for s in reversed(node.src): stack.append((s, False)) # push srcs on the stack + for s in reversed(node.src if enter_calls or node.op is not Ops.CALL else node.src[1:]): + stack.append((s, False)) # push srcs on the stack else: cache[node] = None # second time i'm seeing this node, add it to returned toposort return cache @@ -253,6 +254,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass): case Ops.RESHAPE: if self.src[0]._shape is None: return self.marg + # MULTI marker (axis info in PARAM sources) has no shape + case Ops.MULTI if len(self.src) == 0: return None + # movement ops change the shape # NOTE: ssimplify is required because the shape needs to be canonical for broadcasting and same shape checking if self.op in GroupOp.Movement.union({Ops.MULTI, Ops.REDUCE_AXIS, Ops.WMMA}): @@ -514,6 +518,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass): # COPY removes axis. TODO: add more tests for this, and consider MSELECT/MSTACK if self.op is Ops.COPY: return None if self.op is Ops.MULTI: return self.arg + # PARAM: axis is stored as a MULTI source + if self.op is Ops.PARAM: + for s in self.src: + if s.op is Ops.MULTI: return s.arg + return None # NOTE: they all have to share an axis, we always choose [-1] if self.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in self.src if x.axis is not None])) else None if len(self.src) == 0: return None @@ -867,9 +876,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass): def param_like(self, slot:int): if self.op is Ops.BIND: return UOp.param(slot, self.dtype, self._shape, self._device, self._min_max, self.src[0].arg[0]) - if self.axis is not None: - return UOp.param(slot, self.dtype, self.shard_shape, self._device).multi(self.axis) - return UOp.param(slot, self.dtype, self._shape, self._device) + p = UOp.param(slot, self.dtype, self._shape, self._device) + if self.axis is not None: p = p.replace(src=p.src + (UOp(Ops.MULTI, arg=self.axis),)) + return p def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(), name:str|None=None) -> UOp: # TODO: reenable this after ENCDEC is fixed