From e82ba1454b0ddbe9e994c7416dc223bd5334b99a Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 24 Jan 2025 13:28:55 +0900 Subject: [PATCH] MultiLazyBuffer is UOp [pr] (#8662) * MultiLazyBuffer is UOp [pr] * this is new mlb * this is the idea * progress * multitensor works * more movement ops * this * MultiLazyBuffer is UOp * cleanups * multi axis * fix more tests * work * not that * add multi grad and move shard to ops * mops not views * no double contig * sweet, all mt tests passing * port old logic * remove lbs * fix realized * whitespace * assign tweak * test_assign_kv_cache_multi passes * fix is_realized * fix JIT for multi * just a few more lines i'll pay them back soon i swear please bro just a few more * no split reduceop for multi --- .github/workflows/test.yml | 4 +- examples/hlb_cifar10.py | 3 - test/test_multitensor.py | 81 +++++++++---- tinygrad/engine/jit.py | 5 +- tinygrad/engine/schedule.py | 3 +- tinygrad/gradient.py | 2 +- tinygrad/multi.py | 228 ++++++++++++++++++------------------ tinygrad/nn/state.py | 7 +- tinygrad/ops.py | 67 +++++++++-- tinygrad/tensor.py | 85 ++++++-------- tinygrad/viz/serve.py | 2 +- 11 files changed, 277 insertions(+), 210 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index be82930b62..c95325349e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -243,8 +243,8 @@ jobs: run: | PYTHONPATH="." python test/external/fuzz_shapetracker.py PYTHONPATH="." python test/external/fuzz_shapetracker_math.py - - name: Repo line count < 11000 lines - run: MAX_LINE_COUNT=11000 python sz.py + - name: Repo line count < 11100 lines + run: MAX_LINE_COUNT=11100 python sz.py testopencl: strategy: diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py index 78c59bdb18..d2008ef87d 100644 --- a/examples/hlb_cifar10.py +++ b/examples/hlb_cifar10.py @@ -11,7 +11,6 @@ from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit from tinygrad.nn.state import get_state_dict, get_parameters from tinygrad.nn import optim from tinygrad.helpers import Context, BEAM, WINO, getenv, colored, prod -from tinygrad.multi import MultiLazyBuffer cifar_mean = [0.4913997551666284, 0.48215855929893703, 0.4465309133731618] cifar_std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628] @@ -35,8 +34,6 @@ class UnsyncedBatchNorm: self.num_batches_tracked = Tensor.zeros(1, dtype=dtypes.int, requires_grad=False) def __call__(self, x:Tensor): - if isinstance(x.lazydata, MultiLazyBuffer): assert x.lazydata.axis is None or x.lazydata.axis == 0 and len(x.lazydata.lbs) == self.num_devices - xr = x.reshape(self.num_devices, -1, *x.shape[1:]).cast(dtypes.float32) batch_mean, batch_invstd = self.calc_stats(xr) ret = xr.batchnorm( diff --git a/test/test_multitensor.py b/test/test_multitensor.py index b34baced75..25f863568d 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -1,11 +1,11 @@ import unittest, functools, random from typing import List -from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes -from tinygrad.ops import Ops +from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes, Variable +from tinygrad.ops import Ops, UOp from tinygrad.helpers import CI, getenv, prod, Context from tinygrad.nn.state import get_parameters, get_state_dict from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner, run_schedule -from tinygrad.multi import all_reduce, MultiLazyBuffer +from tinygrad.multi import 
all_reduce import numpy as np from hypothesis import given, strategies as strat, settings from tinygrad.device import is_dtype_supported @@ -30,7 +30,7 @@ N = 128 def _test_allreduce(t:Tensor): aa = (t[0:64] + t[64:128] + t[128:192] + t[192:256]).repeat([4,1]).realize() ts = t.shard(devices_4, 0).realize() - b = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, ts.lazydata.lbs), 0)) + b = Tensor(UOp.multi(*all_reduce(Ops.ADD, ts.lazydata.src), axis=0)) b.realize() return aa, b @@ -39,7 +39,7 @@ class TestMultiTensor(unittest.TestCase): def test_to(self): X = Tensor.ones(256).contiguous().realize() X.to_(devices_2) - for lb in X.lazydata.lbs: + for lb in X.lazydata.src: assert lb.shape == (256,) (X + X).realize() @@ -52,7 +52,7 @@ class TestMultiTensor(unittest.TestCase): def test_shard(self): X = Tensor.ones(256).contiguous().realize() X.shard_(devices_2, 0) - for lb in X.lazydata.lbs: + for lb in X.lazydata.src: assert lb.shape == (128,) (X + X).realize() @@ -218,9 +218,9 @@ class TestMultiTensor(unittest.TestCase): shape = tuple([(n if i == 0 else 1) * random.randint(1, 10) for i in range(random.randint(1, 4))]) t = Tensor.rand(shape).shard_(tuple([d0, d1, d2, d3][:n]), 0) with Context(RING=0): - a = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, t.lazydata.lbs), 0)) + a = Tensor(UOp.multi(*all_reduce(Ops.ADD, t.lazydata.src), axis=0)) with Context(RING=2): - b = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, t.lazydata.lbs), 0)) + b = Tensor(UOp.multi(*all_reduce(Ops.ADD, t.lazydata.src), axis=0)) diff = a - b mean_err = diff.reshape((prod(diff.shape),)).abs().mean().numpy() max_err = diff.reshape((prod(diff.shape),)).abs().max().numpy() @@ -356,8 +356,8 @@ class TestMultiTensor(unittest.TestCase): for p in get_parameters(m): p.shard_(devices_2).realize() GlobalCounters.reset() shard_output = m(fake_image_sharded).log_softmax().realize() - assert shard_output.lazydata.lbs[0].shape == (1, 1000) - assert shard_output.lazydata.lbs[1].shape == (1, 1000) + assert shard_output.lazydata.src[0].shape == (1, 1000) + assert shard_output.lazydata.src[1].shape == (1, 1000) shard_output_np = shard_output.numpy() np.testing.assert_allclose(real_output, shard_output_np, atol=1e-6, rtol=1e-6) @@ -386,12 +386,35 @@ class TestMultiTensor(unittest.TestCase): GlobalCounters.reset() optimizer.zero_grad() shard_output = m(fake_image_sharded).sparse_categorical_crossentropy(labels_sharded, label_smoothing=0.1) - assert shard_output.lazydata.axis is None shard_output.backward() shard_grad = m.conv1.weight.grad.numpy() # sometimes there is zeros in these grads... why? 
np.testing.assert_allclose(grad, shard_grad, atol=1e-5, rtol=1e-5) + def test_assign_kv_cache_multi(self): + bsz, max_context = 2, 8 + + class Attn: + @TinyJit + def __call__(self, xk:Tensor, start_pos:UOp): + seqlen = xk.shape[1] + if not hasattr(self, "cache_k"): + self.cache_k = Tensor.zeros(bsz, max_context, 1, 1).shard(devices_2).contiguous().realize() + keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1).contiguous() if start_pos > 0 else xk + self.cache_k.assign(keys.pad((None,(0,max_context-start_pos-seqlen),None,None)).contiguous()).realize() + + attn = Attn() + xk = Tensor.ones(bsz, 3, 1, 1).shard(devices_2).contiguous() + attn(xk, 0) + for i in range(3,6): + # copied from LLaMA + start_pos = Variable("start_pos", 1, max_context).bind(i) + xk = Tensor.ones(bsz, 1, 1, 1).shard(devices_2).contiguous() + attn(xk, start_pos) + + out = attn.cache_k.flatten().numpy() + np.testing.assert_allclose(out, [1.,1.,1.,1.,1.,1.,0.,0.,1.,1.,1.,1.,1.,1.,0.,0.]) + def test_multi_tensor_jit_param(self): @TinyJit def jf(a, b) -> Tensor: @@ -532,13 +555,13 @@ class TestMultiTensor(unittest.TestCase): t4 = t2.reshape((26, 105,)) for t in [t0, t1, t2, t3, t4]: - assert t.lazydata.axis == 1 np.testing.assert_allclose(t.numpy().flatten(), t0.numpy().flatten()) + assert t.lazydata.axis == 1 # test shape-one axis t5 = t4.reshape((26, 1, 105)) - assert t5.lazydata.axis == 2 np.testing.assert_allclose(t.numpy().flatten(), t5.numpy().flatten()) + assert t5.lazydata.axis == 2 # test split and rejoin to the right and reshape to the left t5 = t0.reshape((2, 13, 3, 5, 7)) @@ -553,7 +576,7 @@ class TestMultiTensor(unittest.TestCase): # test no left join with self.assertRaises((AssertionError, ValueError)): - t0.reshape((26*15,7)) + t0.reshape((26*15,7)).schedule() @unittest.skip("no longer supports uneven shard") def test_reshape_on_axis_uneven(self): @@ -588,6 +611,7 @@ class TestMultiTensor(unittest.TestCase): with self.assertRaises(AssertionError): # don't allow assigns that change axes t_none.assign(t_zero) + t_none.schedule() def test_init_rand_with_multiple_devices_fail(self): # init rand with multi device is not allowed @@ -635,7 +659,7 @@ class TestMultiTensor(unittest.TestCase): self.assertEqual(t.device, t2.device) self.assertEqual(t.dtype, t2.dtype) self.assertEqual(t.lazydata.axis, t2.lazydata.axis) - assert all(tlb.shape == t2lb.shape for tlb, t2lb in zip(t.lazydata.lbs, t2.lazydata.lbs)) + assert all(tlb.shape == t2lb.shape for tlb, t2lb in zip(t.lazydata.src, t2.lazydata.src)) def test_rand_like_none_shard(self): t = Tensor.empty((16, 16)).shard(devices_2) @@ -718,7 +742,7 @@ class TestMultiTensor(unittest.TestCase): devices = (d0, d1, d2, d3) t = Tensor.zeros(16, 16).contiguous() t.shard_(devices, axis=0).realize() - assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.lazydata.lbs]) + assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.lazydata.src]) @unittest.skip("this is unreliable on OSX") def test_clone(self): @@ -774,25 +798,31 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase): with self.assertRaises(AssertionError): # sharded axis shrink on non-device boundry is not allowed a = t.shrink(((0, 3), (0, 8))) + a.schedule() with self.assertRaises(AssertionError): # cannot shrink sharded and non-sharded axis at the same time a = t.shrink(((0, 2), (2, 4))) + a.schedule() a = t.shrink(((0, 2), (0, 8))) + a.schedule() assert a.shape == (2, 8) - assert a.lazydata.real == [True, False, False, False] + assert 
a.lazydata.real == (True, False, False, False) with self.assertRaises(AssertionError): # cannot pad sharded and non-sharded axis at the same time p = a.pad(((0, 6), (0, 1))) + p.schedule() with self.assertRaises(AssertionError): # can only pad to whole axis p = a.pad(((1, 5), (0, 0))) + p.schedule() p = a.pad(((0, 6), (0, 0))) + p.schedule() assert p.shape == (8, 8) - assert p.lazydata.real == [True, True, True, True] + assert p.lazydata.real == (True, True, True, True) @given(strat.sampled_from([dtypes.float, dtypes.int, dtypes.int64, dtypes.int16])) def test_ops(self, dtype): @@ -804,8 +834,8 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase): a = t.shrink(((0+2*i,2+2*i),None)) b = Tensor(t.numpy()[0+2*i:2+2*i]) assert a.shape == b.shape == (2, 8) - assert a.lazydata.real == [i==j for j in range(4)] np.testing.assert_allclose(a.numpy(), b.numpy()) + assert a.lazydata.real == tuple(i==j for j in range(4)) # cast np.testing.assert_allclose(a.float().numpy(), b.float().numpy()) @@ -865,18 +895,20 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase): a = t.shrink(((2, 4), None)) b = t.shrink(((6, 8), None)) - self.assertEqual(a.lazydata.real, [False, True, False, False]) - self.assertEqual(b.lazydata.real, [False, False, False, True]) na = t.numpy()[2:4] nb = t.numpy()[6:8] np.testing.assert_equal(a.numpy(), na) np.testing.assert_equal(b.numpy(), nb) + self.assertEqual(a.lazydata.real, (False, True, False, False)) + self.assertEqual(b.lazydata.real, (False, False, False, True)) with self.assertRaises(AssertionError): # cannot add directly c = a + b + c.schedule() c = a.pad(((2, 4), None)) + b.pad(((6, 0), None)) - self.assertEqual(c.lazydata.real, [True, True, True, True]) + c.realize() + self.assertEqual(c.lazydata.real, (True, True, True, True)) expected = np.concatenate([np.zeros_like(t.numpy()[0:2]), na, np.zeros_like(t.numpy()[4:6]), nb]) np.testing.assert_equal(c.numpy(), expected) @@ -937,8 +969,9 @@ class TestBatchNorm(unittest.TestCase): def __call__(self, x:Tensor): bn_ts = [] - for bound, bn in zip(x.lazydata.bounds, self.bns): - xi = x.shrink((bound, None, None, None)) + each = x.shape[0]//len(self.bns) + for i, bn in enumerate(self.bns): + xi = x.shrink(((each*(i), each*(i+1)), None, None, None)) bni = bn(xi) bn_ts.append(bni) return bn_ts[0].cat(*bn_ts[1:]) diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index f718ef311e..afb9d726a4 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -4,7 +4,7 @@ from tinygrad.tensor import Tensor from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, dedup, partition, unwrap from tinygrad.device import Buffer, Compiled, Device from tinygrad.dtype import DType -from tinygrad.ops import UOp, Variable, sym_infer +from tinygrad.ops import UOp, Variable, sym_infer, Ops from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates from tinygrad.engine.memory import _internal_memory_planner @@ -194,7 +194,8 @@ def _prepare_jit_inputs(args, kwargs): input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor] names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors] if tensors: Tensor.realize(*tensors) - lbs: list[UOp] = flatten([t.lazydata.lbs for t in tensors]) + # TODO: should we be unpacking multi here? 
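+    # (a MULTI UOp's src holds one lazybuffer per device, so each shard's realized buffer becomes a JIT input)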
+    lbs: list[UOp] = flatten([t.lazydata.src if t.lazydata.op is Ops.MULTI else [t.lazydata] for t in tensors])
   input_buffers: list[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
   assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
   st_varval_dtype_device = [(*unwrap(lb.st).unbind(), lb.dtype, lb.device) for lb in lbs]
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index d8236d47bf..7be4360ac1 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -104,6 +104,7 @@ def add_buffers(buf:UOp, tensor_map:dict[UOp, list[UOp]], ctx:ScheduleContext, c
     if DEBUG >= 2: print(f"forcing image {dtype} with shape {buf.shape} to {dtype.base}")
     dtype = buf.dtype.base
   # ASSIGN already has a target buffer, otherwise we create a new one
+  assert isinstance(buf.device, str), f"buf device must be str, not {buf.device}"
   buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
   op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src))
   # track the underlying tensor uop for this buffer
@@ -418,7 +419,7 @@ def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs)
   return x.view(unwrap(view.st))

 def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
-  if not b.device.startswith("DISK"): return None
+  if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None
   buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
   return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py
index 756a9c9785..f64f443858 100644
--- a/tinygrad/gradient.py
+++ b/tinygrad/gradient.py
@@ -36,7 +36,7 @@ pm_gradient = PatternMatcher([
   # TODO: this cast can be removed by putting the casts around the EXPAND
   (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)).cast(ctx.dtype),)),
-
+  (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
   # there's no gradient for bitcast
   (UPat(Ops.BITCAST), lambda ctx: (None,)),
 ])
diff --git a/tinygrad/multi.py b/tinygrad/multi.py
index e6e165bb25..8a7b9d04a0 100644
--- a/tinygrad/multi.py
+++ b/tinygrad/multi.py
@@ -1,8 +1,7 @@
 from __future__ import annotations
 import functools, itertools, operator
 from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv
-from tinygrad.dtype import DType
-from tinygrad.ops import Ops, MathTrait, UOp, sint
+from tinygrad.ops import Ops, UOp, sint

 def all_reduce(bop: Ops, lbs: list[UOp]) -> list[UOp]:
   assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
@@ -40,133 +39,130 @@ def to_sharded(lbs:list[UOp], axis:int, bounds: tuple[tuple[int, int], ...]) ->
   if lbs[0].shape[axis] % len(lbs) != 0: raise RuntimeError(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}")
   return [lb.shrink(tuple((0,s) if a != axis else bound for a,s in enumerate(lb.shape))) for i, (bound, lb) in enumerate(zip(bounds, lbs))]

-class MultiLazyBuffer(MathTrait):
-  def __init__(self, lbs:list[UOp], axis:int|None, real:list[bool]|None=None):
-    assert all(isinstance(x, UOp) for x in lbs) and len(lbs), "all lbs must be LazyBuffers, and we need at least one of them"
-    assert all_same([x.dtype for x in lbs]), f"all multilazybuffer needs same dtype, getting {[x.dtype for x in lbs]}"
-    self.lbs, self.axis, self.dtype, self.device, self.real = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs), real or [True]*len(lbs)
+# ***** multi functions *****
-  @property
-  def shape(self): return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))
+from tinygrad.ops import PatternMatcher, UPat, GroupOp, graph_rewrite_map, track_rewrites
-  @property
-  def size(self): return sum(x.size for x in self.real_lbs)
+def alu_multi(root:UOp):
+  msrcs = root.src
+  assert all(x.op is Ops.MULTI for x in msrcs), f"all buffers must be MULTI {[x.op for x in msrcs]}"
+  assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
-  @property
-  def real_lbs(self): return [lb for lb,r in zip(self.lbs, self.real) if r]
+  # NOTE: they all have to share an axis, we always choose [-1]
+  axis, bounds = axes[-1] if len(axes := dedup([(x.axis, x.bounds) for x in msrcs if x.axis is not None])) else (None, None)
+  srcs:list[list[UOp]] = []
+  not_all_real = not all(all(mlb.real) for mlb in msrcs)
+  new_real = [all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])] if not_all_real else msrcs[0].real
+  assert any(new_real), "output contains no real lb"
+  for mlb in msrcs:
+    if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(list(mlb.src))
+    else:
+      assert axis is not None and bounds is not None
+      if mlb.axis is None: srcs.append(to_sharded(list(mlb.src), axis, bounds))
+      else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.src], axis, bounds))
+  new_real_lbs:dict[int,UOp] = {i:lsrcs[0].alu(root.op, *lsrcs[1:]) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r}
+  # NOTE: const dtype should match real
+  new_dtype = next(iter(new_real_lbs.values())).dtype
+  new_lbs = [new_real_lbs.get(i, lsrcs[0].const_like(0).cast(new_dtype)) for i,lsrcs in enumerate(zip(*srcs))]
+  return UOp.multi(*new_lbs, axis=axis, real=new_real)
-  @property
-  def bounds(self):
-    if self.axis is None: raise RuntimeError("bounds is not defined when axis is None")
-    return tuple(itertools.pairwise(itertools.accumulate([lb.shape[self.axis] for lb in self.lbs], initial=0)))
+def reduce_multi(root:UOp, multi:UOp):
+  op, axis = root.arg
+  if multi.axis is not None and multi.axis in axis:
+    # all-reduce on sharded axes
+    reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(multi.src, multi.real)]
+    # if all partitions are real, do all_reduce
+    if all(multi.real): return UOp.multi(*all_reduce(op, reduced_parts), axis=None)
+    # only one partition is real, keep it
+    return UOp.multi(*reduced_parts, axis=None, real=multi.real)
+  # reduce on non-sharded axes, piecewise is fine. 
if axis is None this is also correct + return UOp.multi(*[x.r(op, axis) for x in multi.src], axis=multi.axis, real=multi.real) - def __repr__(self): return f"" +def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]: + return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape)) - def copy_to_device(self, device:str) -> UOp: - # if we already have a copy on the device, return that - if self.axis is None: return next((lb for lb in self.real_lbs if lb.device == device), self.real_lbs[0].copy_to_device(device)) - # copy lbs to device, pad to final shape, and sum - llbs:list[UOp] = [] - for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds): - if not real: continue - pad_arg = tuple((0,0) if a != self.axis else (start, self.bounds[-1][1]-end) for a in range(len(lb.shape))) - llbs.append(lb.copy_to_device(device).pad(pad_arg)) - return functools.reduce(operator.add, llbs) +def reshape_multi(root:UOp, multi:UOp): + arg = root.arg + if multi.axis is None: return UOp.multi(*[x.reshape(arg) for x in multi.src], axis=None, real=multi.real) + assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)" + arg_acc:list[sint] = list(itertools.accumulate(arg, operator.mul, initial=1)) + # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards + # todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1? + new_axis = len(arg_acc) - arg_acc[::-1].index(prod(multi.shape[:multi.axis])) - 1 + assert all(prod(lb.shape[multi.axis:])%prod(arg[new_axis+1:])==0 for lb in multi.src), \ + f"reshape cannot move items between shards {multi.shape} -> {root.arg=}" + lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[multi.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in multi.src] + return UOp.multi(*lbs, axis=new_axis, real=multi.real) - # passthroughs - @property - def is_realized(self) -> bool: return all(lb.base.realized is not None for lb in self.real_lbs) - def cast(self, dtype:DType): return MultiLazyBuffer([x.cast(dtype) for x in self.lbs], self.axis, self.real) - def bitcast(self, dtype:DType): return MultiLazyBuffer([x.bitcast(dtype) for x in self.lbs], self.axis, self.real) - def const_like(self, b) -> MultiLazyBuffer: return MultiLazyBuffer([x.const_like(b) for x in self.lbs], self.axis, self.real) - def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real) - def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real) - def clone(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.clone() for lb in self.lbs], self.axis, self.real) - def detach(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.detach() for lb in self.lbs], self.axis, self.real) - @property - def toposort(self) -> dict[UOp, None]: return {l:None for x in self.lbs for l in x.toposort} +def expand_multi(root:UOp, multi:UOp): + # NOTE: this assert isn't needed, sharded axis can have dim 1 + assert multi.axis is None or root.arg[multi.axis] == multi.shape[multi.axis], f"expand not supported on sharded axis {root.arg=}" + return UOp.multi(*[x.expand(_shape_to_single_shard(multi.axis, root.arg, x)) for x in multi.src], axis=multi.axis, real=multi.real) - # elementwise is simple - def alu(self, op:Ops, *in_srcs:MultiLazyBuffer) -> MultiLazyBuffer: - msrcs = (self,)+in_srcs - assert all(isinstance(x, MultiLazyBuffer) for x in msrcs), f"all buffers must be 
MultiLazyBuffer {msrcs}" - assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}" +def pad_multi(root:UOp, multi:UOp): + assert multi.axis is None or root.arg[multi.axis] == (0,0) or not all(multi.real), f"padding not supported for {root.arg=}" + # pad on shard axis -> fill others with zeros and set real to all True + if multi.axis is not None and root.arg[multi.axis] != (0,0): + # pad back to whole axis, remove real mask + assert all(root.arg[i] == (0, 0) for i in range(len(multi.shape)) if i != multi.axis), "cannot pad sharded and non-sharded axis at the same time" + dim, bound = sum(lb.shape[multi.axis] for lb in multi.src), multi.bounds[multi.real.index(True)] + assert root.arg[multi.axis] == (bound[0], dim-bound[1]), "can only pad to whole axis" + return UOp.multi(*[x if r else x.const_like(0) for x,r in zip(multi.src, multi.real)], axis=multi.axis) + return UOp.multi(*[x.pad(root.arg) for x in multi.src], axis=multi.axis, real=multi.real) - # NOTE: they all have to share an axis, we always choose [-1] - axis, bounds = axes[-1] if len(axes := dedup([(x.axis, x.bounds) for x in msrcs if x.axis is not None])) else (None, None) - srcs:list[list[UOp]] = [] - not_all_real = not all(all(mlb.real) for mlb in msrcs) - new_real = [all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])] if not_all_real else self.real - assert any(new_real), "output contains no real lb" - for mlb in msrcs: - if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(mlb.lbs) - else: - assert axis is not None and bounds is not None - if mlb.axis is None: srcs.append(to_sharded(mlb.lbs, axis, bounds)) - else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds)) - new_real_lbs:dict[int,UOp] = {i:lsrcs[0].alu(op, *lsrcs[1:]) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r} - # NOTE: const dtype should match real - new_dtype = next(iter(new_real_lbs.values())).dtype - return MultiLazyBuffer([new_real_lbs.get(i, lsrcs[0].const_like(0).cast(new_dtype)) for i,lsrcs in enumerate(zip(*srcs))], axis, new_real) +def permute_multi(root:UOp, multi:UOp): + # all permutes supported! + return UOp.multi(*[x.permute(root.arg) for x in multi.src], axis=root.arg.index(multi.axis) if multi.axis is not None else None, real=multi.real) - def r(self, op:Ops, axis:tuple[int, ...]) -> MultiLazyBuffer: - if self.axis is not None and self.axis in axis: - # all-reduce on sharded axes - reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(self.lbs, self.real)] - # if all partitions are real, do all_reduce - if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None) - # only one partition is real, keep it - return MultiLazyBuffer(reduced_parts, None, self.real) - # reduce on non sharded axes, piecewise is fine. 
if axis is None this is also correct - return MultiLazyBuffer([x.r(op, axis) for x in self.lbs], self.axis, self.real) +def shrink_multi(root:UOp, multi:UOp): + assert multi.axis is None or root.arg[multi.axis] == (0, multi.shape[multi.axis]) or root.arg[multi.axis] in multi.bounds, \ + f"shrinking not supported for {root.arg=}" + if multi.axis is not None and root.arg[multi.axis] in multi.bounds and root.arg[multi.axis] != (0, multi.shape[multi.axis]): + assert all(root.arg[i] == (0, s) or i == multi.axis for i,s in enumerate(multi.shape)), \ + "cannot shrink sharded and non-sharded axis at the same time" + # NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real + idx = multi.bounds.index(root.arg[multi.axis]) + # zero out other lbs to not create lb reference + return UOp.multi(*[lb if i==idx else lb.const_like(0) for i,lb in enumerate(multi.src)], + axis=multi.axis, real=[i==idx for i in range(len(multi.src))]) + return UOp.multi(*[x.shrink(tuple((0, x.shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.arg))) for x in multi.src], + axis=multi.axis, real=multi.real) - # *** movement ops *** +def stride_multi(root:UOp, multi:UOp): + assert multi.axis is None or root.arg[multi.axis] == 1, "flipping not supported on sharded axis" + return UOp.multi(*[x.stride(root.arg) for x in multi.src], axis=multi.axis, real=multi.real) - def _shape_to_single_shard(self, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]: - return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape)) +def copy_multi(multi:UOp, device:UOp): + # if we already have a copy on the device, return that + if multi.axis is None: return next((lb for lb in multi.real_lbs if lb.device == device.arg), multi.real_lbs[0].copy_to_device(device.arg)) + # copy lbs to device, pad to final shape, and sum + llbs:list[UOp] = [] + for lb,real,(start,end) in zip(multi.src, multi.real, multi.bounds): + if not real: continue + pad_arg = tuple((0,0) if a != multi.axis else (start, multi.bounds[-1][1]-end) for a in range(len(lb.shape))) + llbs.append(lb.copy_to_device(device.arg).pad(pad_arg)) + return functools.reduce(operator.add, llbs) - def reshape(self, arg:tuple[sint, ...]): - if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None, self.real) - assert prod(self.shape) == prod(arg), "reshape must maintain prod(shape)" - arg_acc:list[sint] = list(itertools.accumulate(arg, operator.mul, initial=1)) - # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards - # todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1? 
- new_axis = len(arg_acc) - arg_acc[::-1].index(prod(self.shape[:self.axis])) - 1 - assert all(prod(lb.shape[self.axis:])%prod(arg[new_axis+1:])==0 for lb in self.lbs), f"reshape cannot move items between shards {self=} {arg=}" - lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[self.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in self.lbs] - return MultiLazyBuffer(lbs, new_axis, self.real) +def assign_multi(dest:UOp, src:UOp): + assert dest.axis == src.axis and dest.real == src.real, f"axis/real must match in assign {dest.axis} != {src.axis} or {dest.real} != {src.real}" + return UOp.multi(*[x.assign(y) for x,y in zip(dest.src, src.src)], axis=src.axis, real=src.real) - def pad(self, arg:tuple[tuple[sint, sint], ...]): - assert self.axis is None or arg[self.axis] == (0,0) or not all(self.real), f"padding not supported for {arg=}" - # pad on shard axis -> fill others with zeros and set real to all True - if self.axis is not None and arg[self.axis] != (0,0): - # pad back to whole axis, remove real mask - assert all(arg[i] == (0, 0) for i in range(len(self.shape)) if i != self.axis), "cannot pad sharded and non-sharded axis at the same time" - dim, bound = sum(lb.shape[self.axis] for lb in self.lbs), self.bounds[self.real.index(True)] - assert arg[self.axis] == (bound[0], dim-bound[1]), "can only pad to whole axis" - return MultiLazyBuffer([x if r else x.const_like(0) for x,r in zip(self.lbs, self.real)], self.axis) - return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis, self.real) +def passthrough_multi(root:UOp, multi:UOp): return UOp.multi(*[root.replace(src=(m,)) for m in multi.src], axis=multi.axis, real=multi.real) - def expand(self, arg:tuple[sint, ...]): - # NOTE: this assert isn't needed, sharded axis can have dim 1 - assert self.axis is None or arg[self.axis] == self.shape[self.axis], f"expand not supported on sharded axis {arg=}" - return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg, x)) for x in self.lbs], self.axis, self.real) +# NOTE: this is the same pattern as Ops.UNROLL +multi_pm = PatternMatcher([ + (UPat(GroupOp.ALU, name="root", custom_early_reject=set([Ops.MULTI])), alu_multi), + (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reduce_multi), + (UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reshape_multi), + (UPat(Ops.EXPAND, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), expand_multi), + (UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), pad_multi), + (UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi), + (UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), shrink_multi), + (UPat(Ops.STRIDE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), stride_multi), + (UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi), + (UPat(Ops.COPY, src=(UPat(Ops.DEVICE, name="device"), UPat(Ops.MULTI, name="multi"), )), copy_multi), + (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH), src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi), +]) - def permute(self, arg:tuple[int, ...]): - # all permutes supported! 
- return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None, self.real) - - def shrink(self, arg:tuple[tuple[sint, sint], ...]): - assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]) or arg[self.axis] in self.bounds, f"shrinking not supported for {arg=}" - if self.axis is not None and arg[self.axis] in self.bounds and arg[self.axis] != (0, self.shape[self.axis]): - assert all(arg[i] == (0, s) or i == self.axis for i,s in enumerate(self.shape)), "cannot shrink sharded and non-sharded axis at the same time" - # NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real - idx = self.bounds.index(arg[self.axis]) - # zero out other lbs to not create lb reference - return MultiLazyBuffer([lb if i==idx else lb.const_like(0) for i,lb in enumerate(self.lbs)], self.axis, [i==idx for i in range(len(self.lbs))]) - return MultiLazyBuffer([x.shrink(tuple((0, x.shape[self.axis]) if a == self.axis else s for a,s in enumerate(arg))) for x in self.lbs], - self.axis, self.real) - - def stride(self, arg:tuple[int, ...]): - assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis" - return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis, self.real) +@track_rewrites(named=True) +def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]: return {k:v for k,v in graph_rewrite_map(big_sink, multi_pm).items() if k is not v} diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py index 9400dd2e82..99032cfe63 100644 --- a/tinygrad/nn/state.py +++ b/tinygrad/nn/state.py @@ -5,7 +5,6 @@ from tinygrad.tensor import Tensor from tinygrad.dtype import dtypes from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up, T from tinygrad.shape.view import strides_for_shape -from tinygrad.multi import MultiLazyBuffer class TensorIO(io.RawIOBase, BinaryIO): def __init__(self, t: Tensor): @@ -152,9 +151,9 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr continue if v.shape != state_dict[k].shape: raise ValueError(f'Shape mismatch in layer `{k}`: Expected shape {v.shape}, but found {state_dict[k].shape} in state dict.') - if isinstance((mlb:=v.lazydata), MultiLazyBuffer): - if isinstance(state_dict[k].lazydata, MultiLazyBuffer): v.replace(state_dict[k]).realize() - else: v.replace(state_dict[k].shard(mlb.device, mlb.axis)).realize() + if isinstance(v.device, tuple): + if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k]).realize() + else: v.replace(state_dict[k].shard(v.device, v.lazydata.axis)).realize() else: v.replace(state_dict[k].to(v.device)).realize() if consume: del state_dict[k] diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 0225603cc4..cf40c1fbed 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -150,6 +150,7 @@ class Ops(FastEnum): # device DEVICE = auto() + MULTI = auto() class GroupOp: Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG} @@ -281,6 +282,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass): @functools.cached_property def st(self) -> ShapeTracker|None: + from tinygrad.shape.shapetracker import ShapeTracker + if self.op is Ops.MULTI: + return ShapeTracker.from_shape( + tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))) # these ops define a ShapeTracker from the arg if self.op is Ops.VIEW: return self.arg if self.op in GroupOp.Movement: return 
unwrap(self.src[0].st).mop(self.op, self.arg) @@ -294,7 +299,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass): # only reduce ops are allowed to change shape, everything else derives shape from sources elif self.op in {Ops.REDUCE_AXIS, Ops.WMMA}: shape = src_sts[0].reduce(self.axis_arg) else: shape = src_sts[0].shape - from tinygrad.shape.shapetracker import ShapeTracker return ShapeTracker.from_shape(shape) @functools.cached_property @@ -350,7 +354,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx)) def const_like(self, b:ConstLike): # constants can optionally have a DEVICE source - return UOp.const(self.dtype, b) if self._device is None else UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b) + if self._device is None: return UOp.const(self.dtype, b) + if isinstance(self.device, tuple): return UOp.multi(*[UOp.metaop(Ops.CONST, self.shape, self.dtype, d, b) for d in self.device], axis=None) + return UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b) def broadcast(self, count:int): assert self.dtype.count == 1 if count == 1: return self @@ -389,7 +395,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass): new_shape = unwrap(self.st).reduce(axis) # TODO: can we split symbolic shape if the reduce axis is not symbolic? - if not SPLIT_REDUCEOP or not all_int(self.shape) or (0 in self.shape) or \ + # TODO: this shouldn't be here, it belongs in scheduler! that's why it broke multi + if not SPLIT_REDUCEOP or isinstance(self._device, tuple) or not all_int(self.shape) or (0 in self.shape) or \ prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return self._reduce_op(op, axis) @@ -410,6 +417,45 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x)) def contiguous(self): return self.alu(Ops.CONTIGUOUS) + # *** from MultiLazyBuffer *** + + def multi(self, *more:UOp, axis:int|None, real:list[bool]|None=None): + parents = (self,)+more + assert all_same([x.dtype for x in parents]), "multi parents must have the same dtype" + return UOp(Ops.MULTI, self.dtype, parents, (axis, tuple(real if real is not None else [True]*len(parents)))) + + @property + def bounds(self): + if self.axis is None: raise RuntimeError("bounds is not defined when axis is None") + return tuple(itertools.pairwise(itertools.accumulate([lb.shape[self.axis] for lb in self.src], initial=0))) + + @property + def axis(self): + assert self.op is Ops.MULTI + return self.arg[0] + + @property + def real(self): + assert self.op is Ops.MULTI + return self.arg[1] + + @property + def real_lbs(self): return [lb for lb,r in zip(self.src, self.real) if r] + + def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> UOp: + if axis is None: lbs = [self] * len(devices) + else: + if self.shape[axis] % len(devices) != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {len(devices)=}") + sz = self.shape[axis] // len(devices) + sizes = [max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))] + lbs, off = [], 0 + for sz in sizes: + lbs.append(self.shrink(tuple((0,s) if i != axis else (off,off+sz) for i,s in enumerate(self.shape)))) + off += sz + sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(lbs, devices)] + # NOTE: this contiguous is making it impossible for the scheduler to do late const folding + return UOp.multi(*[lb.contiguous() for lb in sharded_lbs], 
axis=axis) + # *** from LazyBuffer *** @staticmethod @@ -426,7 +472,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val) # otherwise it's just a VIEW(BUFFER) return UOp(Ops.VIEW, dtype, (UOp.new_buffer(device, (st:=ShapeTracker.from_shape(shape)).size, dtype),), st) - def copy_to_device(self, device:str, clone:bool=False) -> UOp: + def copy_to_device(self, device:str|tuple[str, ...], clone:bool=False) -> UOp: # if it's a shrink, do the shrink before the copy with CONTIGUOUS if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device) # COPY is COPY(DEVICE, copyin.base) -> VIEW(copyin.st) @@ -440,8 +486,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return ret def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True) @property - def lbs(self): return [self] - @property def metadata(self): return all_metadata.get(self, None) # *** uop movement ops *** @@ -470,10 +514,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass): @staticmethod def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device),), (next(UOp.buffer_num), size)) @property - def device(self) -> str: return unwrap(self._device) + def device(self) -> str|tuple[str, ...]: return cast(str|tuple[str, ...], unwrap(self._device)) @functools.cached_property - def _device(self) -> Optional[str]: + def _device(self) -> Optional[str|tuple[str, ...]]: if self.op is Ops.DEVICE: return self.arg + if self.op is Ops.MULTI: return tuple(cast(str, x.device) for x in self.src) return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None @property def buf_uop(self) -> UOp: @@ -489,6 +534,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}" if (cret:=buffers.get(self)) is not None: return cret from tinygrad.device import Buffer + assert isinstance(self.device, str), f"buffer not supported on multi {self.device}" buffers[self] = ret = Buffer(self.device, self.size, self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base) return ret @property @@ -496,7 +542,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is Ops.BUFFER: return self.src[0].realized return self.buffer if self.op is Ops.BUFFER else None @property - def is_realized(self) -> bool: return self.base.realized is not None + def is_realized(self) -> bool: + return all(x.base.realized is not None for x in self.base.real_lbs) if self.base.op is Ops.MULTI else self.base.realized is not None # *** uop Variable stuff *** @@ -639,7 +686,7 @@ def print_uops(uops:list[UOp]): def get_location() -> tuple[str, int]: frm = sys._getframe(1) # find the real frame in the file that has the UPat, TODO: is there a better way to do this? 
- while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "rewriter.py", "schedule.py", + while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "rewriter.py", "schedule.py", "multi.py", "lowerer.py", "cstyle.py", "linearize.py"}: frm = frm.f_back return frm.f_code.co_filename, frm.f_lineno diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 45a9b23749..2c47162c26 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -6,10 +6,10 @@ from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Seque from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap -from tinygrad.multi import MultiLazyBuffer +from tinygrad.multi import get_multi_map from tinygrad.gradient import compute_gradient from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element -from tinygrad.device import Device, Buffer, BufferSpec +from tinygrad.device import Device, BufferSpec from tinygrad.engine.realize import run_schedule from tinygrad.engine.memory import memory_planner from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars @@ -30,18 +30,17 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None: # link the found UOps back to Tensors. exit early if there's no Tensors to realize # NOTE: this uses all_tensors, but it's fast - fixed_tensors: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and any(x in all_uops for x in t.lazydata.lbs)] + fixed_tensors: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and t.lazydata in all_uops] if len(fixed_tensors): # potentially rewrite all the discovered Tensors - sink = UOp.sink(*[UOp.sink(*t.lazydata.lbs) if isinstance(t.lazydata, MultiLazyBuffer) else t.lazydata for t in fixed_tensors]) + sink = UOp.sink(*[t.lazydata for t in fixed_tensors]) new_sink = sink.substitute(applied_map) # set the relevant lazydata to the realized UOps for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src): if s is ns: continue - if isinstance(t.lazydata, MultiLazyBuffer): t.lazydata.lbs = list(ns.src) - else: t.lazydata = ns + t.lazydata = ns # **** start with two base classes, Tensor and Function **** @@ -68,7 +67,7 @@ import tinygrad.function as F def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str, ...]], arg=None): if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg) - return MultiLazyBuffer([UOp.metaop(op, shape, dtype, d, arg) for d in device], None) + return UOp.multi(*[UOp.metaop(op, shape, dtype, d, arg) for d in device], axis=None) def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821 import numpy as np @@ -159,7 +158,7 @@ class Tensor(SimpleMathTrait): return instance def __del__(self): all_tensors.discard(weakref.ref(self)) - def __init__(self, data:Union[None, ConstType, bytes, List, Tuple, UOp, MultiLazyBuffer, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821 + def __init__(self, data:Union[None, ConstType, bytes, List, Tuple, UOp, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821 
device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None): if dtype is not None: dtype = to_dtype(dtype) if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None @@ -176,7 +175,7 @@ class Tensor(SimpleMathTrait): self._ctx: Optional[Function] = None # create a LazyBuffer from the different types of inputs - if isinstance(data, (UOp, MultiLazyBuffer)): + if isinstance(data, UOp): assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported" # NOTE: this is here because LazyBuffer = UOp if isinstance(data, UOp) and data.op is Ops.BIND: data = _metaop(Ops.BIND, tuple(), dtype or data.dtype, device, data) @@ -199,12 +198,12 @@ class Tensor(SimpleMathTrait): data = _metaop(Ops.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}") # by this point, it has to be a LazyBuffer - if not isinstance(data, (UOp, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}") + if not isinstance(data, UOp): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}") # data might be on a different device - if isinstance(device, str): self.lazydata:Union[UOp, MultiLazyBuffer] = data if data.device == device else data.copy_to_device(device) + if isinstance(device, str): self.lazydata:UOp = data if data.device == device else data.copy_to_device(device) # if device is a tuple, we should have/construct a MultiLazyBuffer - elif isinstance(data, UOp): self.lazydata = Tensor(data).shard(device).lazydata + elif isinstance(data, UOp) and isinstance(data.device, str): self.lazydata = Tensor(data).shard(device).lazydata else: assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}" self.lazydata = data @@ -224,8 +223,8 @@ class Tensor(SimpleMathTrait): def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev def __repr__(self): - if isinstance(ld:=self.lazydata, MultiLazyBuffer): ld_repr = f"{ld!r}" - else: ld_repr = f"" + ld = self.lazydata + ld_repr = f"" return f"" # Python has a non moving GC, so this should be okay @@ -254,7 +253,14 @@ class Tensor(SimpleMathTrait): NOTE: A Tensor can only be scheduled once. """ - big_sink = UOp.sink(*flatten([x.lazydata.lbs for x in (self,)+lst])) + big_sink = UOp.sink(*[x.lazydata for x in (self,)+lst]) + + # TODO: move this to scheduler tensor_map pass + if any(x.op is Ops.MULTI for x in big_sink.toposort): + # multi fixup + _apply_map_to_tensors(get_multi_map(big_sink)) + big_sink = UOp.sink(*flatten([x.lazydata.src if x.lazydata.op is Ops.MULTI else [x.lazydata] for x in (self,)+lst])) + schedule, var_vals, becomes_map = create_schedule_with_vars(big_sink) _apply_map_to_tensors(becomes_map) return memory_planner(schedule), var_vals @@ -293,7 +299,6 @@ class Tensor(SimpleMathTrait): assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}" assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}" assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}" - assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer" assert not x.requires_grad # self requires_grad is okay? 
if not self.lazydata.is_realized: return self.replace(x) self.lazydata = self.lazydata.assign(x.lazydata) @@ -309,7 +314,8 @@ class Tensor(SimpleMathTrait): if 0 in self.shape: return memoryview(bytearray(0)) # NOTE: this realizes on the object from as_buffer being a Python object cpu = self.cast(self.dtype.base).contiguous().to("CLANG").realize() - buf = cast(Buffer, cast(UOp, cpu.lazydata).base.realized) + buf = cast(UOp, cpu.lazydata).base.realized + assert buf is not None, f"{cast(UOp, cpu.lazydata).base} was not realized" if self.device != "CLANG": buf.options = BufferSpec(nolru=True) return buf.as_buffer(allow_zero_copy=True if self.device != "CLANG" else False) @@ -405,18 +411,9 @@ class Tensor(SimpleMathTrait): print(t.shard((t.device, t.device), axis=1).lazydata) ``` """ - assert isinstance(self.lazydata, UOp), "can't shard a MultiLazyBuffer" + assert isinstance(self.device, str), "can't shard a MultiLazyBuffer" devices = tuple(Device.canonicalize(x) for x in devices) - if axis is None: lbs = [self.lazydata] * len(devices) - else: - axis = self._resolve_dim(axis) - if self.shape[axis] % len(devices) != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {len(devices)=}") - sz = self.shape[axis] // len(devices) - sizes = [max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))] - lbs = [cast(UOp, t.lazydata) for t in self.split(sizes, axis)] - sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(lbs, devices)] - # NOTE: this contiguous is making it impossible for the scheduler to do late const folding - mlb = MultiLazyBuffer([lb.contiguous() for lb in sharded_lbs], axis) + mlb = self.lazydata.shard(devices, self._resolve_dim(axis) if axis is not None else None) return Tensor(mlb, device=devices, requires_grad=self.requires_grad) def shard_(self, devices:tuple[str, ...], axis:Optional[int]=None): @@ -439,7 +436,7 @@ class Tensor(SimpleMathTrait): def _metaop(op, shape, device:Optional[Union[tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs): dtype = to_dtype(dtype) if dtype is not None else dtypes.default_float if isinstance(device, tuple): - return Tensor(MultiLazyBuffer([UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], None), + return Tensor(UOp.multi(*[UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], axis=None), device, dtype, **kwargs) return Tensor(UOp.metaop(op, shape, dtype, Device.canonicalize(device), arg), device, dtype, **kwargs) @@ -750,12 +747,12 @@ class Tensor(SimpleMathTrait): ``` """ dtype = kwargs.pop("dtype", self.dtype) - if isinstance(self.device, tuple) and isinstance(self.lazydata, MultiLazyBuffer): + if isinstance(self.device, tuple): if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor") if self.lazydata.axis is None: return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device) contiguous = kwargs.pop("contiguous", True) - rands = [Tensor.rand(*lb.shape, device=lb.device, dtype=dtype, contiguous=contiguous, **kwargs).lazydata for lb in self.lazydata.lbs] - return Tensor(MultiLazyBuffer(cast(list[UOp], rands), self.lazydata.axis), device=self.device, dtype=dtype, **kwargs) + rands = [Tensor.rand(*lb.shape, device=cast(str, lb.device), dtype=dtype, contiguous=contiguous, **kwargs).lazydata for lb in self.lazydata.src] + return Tensor(UOp.multi(*rands, axis=self.lazydata.axis), device=self.device, dtype=dtype, **kwargs) return Tensor.rand(*self.shape, 
device=kwargs.pop("device", self.device), dtype=dtype, **kwargs) # ***** rng hlops ***** @@ -921,18 +918,15 @@ class Tensor(SimpleMathTrait): assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor" if gradient is None: gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False) rets = [] - for i,(uop,grad) in enumerate(zip(self.lazydata.lbs, gradient.lazydata.lbs)): - target_uops = [x.lazydata.lbs[i] for x in targets] - grads = compute_gradient(uop, grad, set(target_uops)) - ret = [] - for x in target_uops: - if (y:=grads.get(x)) is None: raise RuntimeError(f"{x}\n\nnot found in\n\n{uop}") - ret.append(y) - rets.append(ret) + target_uops = [x.lazydata for x in targets] + grads = compute_gradient(self.lazydata, gradient.lazydata, set(target_uops)) + ret = [] + for x in target_uops: + if (y:=grads.get(x)) is None: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}") + ret.append(y) + rets.append(ret) # create returned Tensors - if isinstance(self.lazydata, UOp): return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])] - return [Tensor(MultiLazyBuffer(list(u), cast(MultiLazyBuffer, t.lazydata).axis, cast(MultiLazyBuffer, t.lazydata).real), - device=t.device) for t,u in zip(targets, zip(*rets))] + return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])] def _deepwalk(self) -> list[Tensor]: def _walk(node:Tensor, visited:set[Tensor]): @@ -977,8 +971,7 @@ class Tensor(SimpleMathTrait): for t, g in zip(ctx.parents, grads): if g is not None and t.requires_grad: assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}" - assert t.lazydata in toposort_uop or (isinstance(t.lazydata, MultiLazyBuffer) and any(x in toposort_uop for x in t.lazydata.lbs)), \ - f"grad uop must have a path from self\ngrad uop: {t.lazydata}" + assert t.lazydata in toposort_uop, f"grad uop must have a path from self\ngrad uop: {t.lazydata}" t.grad = g if t.grad is None else (t.grad + g) if not retain_graph: del t0._ctx return self @@ -1278,7 +1271,7 @@ class Tensor(SimpleMathTrait): self._getitem(indices).assign(v) return # NOTE: check that setitem target is valid first - if not all(unwrap(lb.st).contiguous for lb in self.lazydata.lbs): raise RuntimeError("setitem target needs to be contiguous") + if not unwrap(self.lazydata.st).contiguous: raise RuntimeError("setitem target needs to be contiguous") if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor") if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype) if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported") diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index c95fb5f800..fdceca9c0b 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -13,7 +13,7 @@ from tinygrad.dtype import dtypes uops_colors = {Ops.LOAD: "#ffc0c0", Ops.PRELOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B", Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff", - Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", + Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in 
GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0"}
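
A minimal usage sketch of the new multi flow (device names are assumed here, mirroring devices_2 in test/test_multitensor.py). After shard_, Tensor.lazydata is a plain UOp with op=Ops.MULTI whose src holds one lazybuffer per device, and multi_pm rewrites ops on it per shard at schedule time:

  from tinygrad import Tensor, Device
  from tinygrad.ops import Ops

  d1, d2 = f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2"    # assumed device pair
  X = Tensor.ones(256).contiguous().realize()
  X.shard_((d1, d2), axis=0)                               # lazydata is now UOp(Ops.MULTI, ...)
  assert X.lazydata.op is Ops.MULTI and X.lazydata.axis == 0
  assert all(lb.shape == (128,) for lb in X.lazydata.src)  # one shard per device
  (X + X).realize()                                        # alu_multi keeps the add sharded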