new test_multitensor tests (#10667)

* new test_multitensor tests

* cleanup scheduler
This commit is contained in:
George Hotz
2025-06-06 10:26:28 -07:00
committed by GitHub
parent 5170f387b3
commit 7f0f97aa76
4 changed files with 64 additions and 32 deletions

View File

@@ -1173,13 +1173,37 @@ class TestMultiAssign(unittest.TestCase):
out[:, 2:3].assign(ones).realize()
self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]])
def test_multi_assign_var_offset(self):
# Assign a 4x1 column of ones into a sharded 4x4 zeros tensor at a
# symbolic offset: Variable "i" (range 0-3) bound to 2, so only column 2
# should change. NOTE(review): `self.device` is the multi-device tuple set
# up by the enclosing TestMultiAssign class (outside this view) — confirm.
out = Tensor.zeros(4,4).contiguous().realize().shard(self.device, 0).realize()
ones = Tensor.ones(4,1).shard(self.device, 0).contiguous().realize()
vi = Variable("i", 0, 3).bind(2)
out[:, vi:vi+1].assign(ones).realize()
self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]])
def test_multi_assign_var_offset_jit_none(self): self.test_multi_assign_var_offset_jit(None)
def test_multi_assign_var_offset_jit(self, shard_axis=0):
# JIT version of the symbolic-offset assign: a TinyJit function writes the
# current `ones` column into `out` at Variable-bound column vi, then bumps
# `ones` by 1. Calls run with i = 1..4 and ones starting at 1, so after the
# loop columns 1..4 hold values 1..4 and columns 0 and 5 stay 0.
out = Tensor.zeros(4,6).contiguous().realize().shard(self.device, shard_axis).realize()
ones = Tensor.ones(4,1).shard(self.device, shard_axis).contiguous().realize()
@TinyJit
def f(out:Tensor, vi):
out[:, vi:vi+1].assign(ones).realize()
# increment the captured `ones` tensor so each jitted call writes a new value
ones.assign(ones+1).realize()
vi = Variable("i", 0, 5)
for i in range(1,5):
# reset counters per call; presumably so per-call kernel counts can be
# inspected while debugging — TODO confirm no assertion depends on it
GlobalCounters.reset()
f(out, vi.bind(i))
self.assertListEqual(out.tolist(), [[0,1,2,3,4,0]]*4)
@unittest.skipIf(not_support_multi_device(), "need multi")
class TestMultiTransformer(unittest.TestCase):
def test_transformer(self):
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
from extra.models.llama import Transformer
args = {"dim": 64, "n_heads": 1, "n_kv_heads": 1, "n_layers": 2, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 1024, "hidden_dim": 64}
args = {"dim": 64, "n_heads": 1, "n_kv_heads": 1, "n_layers": 2, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 1024,
"hidden_dim": 64, "max_context": 12}
real_model = Transformer(**args)
shard_model = Transformer(**args)
@@ -1199,10 +1223,18 @@ class TestMultiTransformer(unittest.TestCase):
last_tok = 0
for i in range(10):
real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), i)
shard_tok = shard_model(Tensor([[last_tok]], device=device), i)
last_tok = real_tok.item()
self.assertEqual(last_tok, shard_tok.item(), f"issue at token {i}")
real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), i).item()
shard_tok = shard_model(Tensor([[last_tok]], device=device), i).item()
# test kv cache
kv1 = real_model.layers[0].attention.cache_kv.numpy()
kv2 = shard_model.layers[0].attention.cache_kv.numpy()
#print(np.concatenate([kv1[:, :, :, :, 0:1], kv2[:, :, :, :, 0:1]], axis=4))
np.testing.assert_allclose(kv1, kv2, atol=1e-5, rtol=1e-5, err_msg=f"issue at token {i}")
# test token
self.assertEqual(real_tok, shard_tok, f"issue at token {i}")
last_tok = real_tok
@unittest.skip("super slow")
def test_llama1b_full(self):

View File

@@ -50,14 +50,6 @@ def copy_reorder_view(copy:UOp, view:UOp, base:UOp):
if prod(view.shape) < prod(base.shape): return view.contiguous().copy_to_device(copy.device)
return base.copy_to_device(copy.device).view(view.arg)
def mselect_reorder_view(ms:UOp, view:UOp, base:UOp):
# NOTE(review): this is the deleted-side copy of this diff; the function moves
# to the multi rewrite file in this commit. Reorders MSELECT(VIEW(base)) into
# VIEW(MSELECT(base)) so MSELECT always selects a base UOp.
st = unwrap(view.st)
# replace dnum in ShapeTracker with literal const for this mselect
if (dnums:=[x for x in st.vars() if x.arg[0] == '_device_num']):
assert len(dnums) == 1, f"view must have exactly 0 or 1 dnum, got {dnums}"
st = st.substitute({dnums[0]:dnums[0].const_like(ms.arg)})
return base.mselect(ms.arg).view(st)
ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK}
sym = symbolic_simple+PatternMatcher([
@@ -79,8 +71,6 @@ sym = symbolic_simple+PatternMatcher([
(UPat(Ops.COPY, name="c", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda c,x: x if c.device == x.device else None),
# store a shrink before COPY, otherwise view after the COPY
(UPat(Ops.COPY, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"), UPat(Ops.DEVICE)), name="copy"), copy_reorder_view),
# MSELECT must select a base, if there are views apply them after selecting the base
(UPat(Ops.MSELECT, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"),), name="ms"), mselect_reorder_view),
# remove cast to image when it's already a contiguous image
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"),)),)),
lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
@@ -570,6 +560,7 @@ def get_kernelize_map(big_sink:UOp) -> dict[UOp, UOp]:
if u.op is not Ops.ASSIGN: continue
kernel_assign[u.buf_uop] = u
for s in u.src[1].src:
# TODO: this is probably broken for MSELECT/MSTACK
if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort()):
raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")

View File

@@ -1,6 +1,6 @@
from typing import cast
import functools, itertools, operator
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv, unwrap
from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp
# *** allreduce implementation ***
@@ -53,7 +53,21 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
pads = [((s,numel-e),) for s,e in chunks]
return functools.reduce(operator.add, [c.pad(pad) for pad,c in zip(pads, copied_chunks)]).reshape(shape)
replace_allreduce = PatternMatcher([(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),])
# ***** multi rewrite MSELECT/MSTACK *****
def mselect_reorder_view(ms:UOp, view:UOp, base:UOp):
# Rewrite MSELECT(VIEW(base)) -> VIEW(MSELECT(base)) so the MSELECT always
# selects a base UOp, with the view applied afterwards. If the view's
# ShapeTracker references the special '_device_num' variable, substitute the
# concrete device index of this mselect (ms.arg) into it first.
st = unwrap(view.st)
# replace dnum in ShapeTracker with literal const for this mselect
if (dnums:=[x for x in st.vars() if x.arg[0] == '_device_num']):
assert len(dnums) == 1, f"view must have exactly 0 or 1 dnum, got {dnums}"
st = st.substitute({dnums[0]:dnums[0].const_like(ms.arg)})
return base.mselect(ms.arg).view(st)
# Multi-device rewrite rules: lower ALLREDUCE via handle_allreduce, and
# normalize MSELECT so it always selects a base (views reordered on top).
replace_allreduce = PatternMatcher([
(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
# MSELECT must select a base, if there are views apply them after selecting the base
(UPat(Ops.MSELECT, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"),), name="ms"), mselect_reorder_view),
])
# ***** multi functions *****

View File

@@ -77,28 +77,23 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
buffers[k.src[0]] = base.view(k.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
ubufs = tuple(s.buf_uop.buffer for s in k.src)
if any(isinstance(x, MultiBuffer) for x in ubufs):
if ast.op is Ops.COPY:
assert ast.arg is None, "copy arg is no longer supported"
if isinstance(ubufs[1], MultiBuffer): # src is multiple buffers, none selected
if isinstance(ubufs[0], MultiBuffer):
# COPY ALL -> ALL
assert len(ubufs[0].bufs) == len(ubufs[1].bufs), "all to all copy must have matching buffer length"
for b1,b2 in zip(ubufs[0].bufs, ubufs[1].bufs): schedule.append(ScheduleItem(ast, (b1, b2), k.arg.metadata))
else:
# COPY ANY -> ONE. Currently we just select the first
schedule.append(ScheduleItem(ast, (ubufs[0], ubufs[1].bufs[0]), k.arg.metadata))
if ast.op is Ops.COPY and (isinstance(ubufs[0], Buffer) or isinstance(ubufs[1], Buffer)):
if isinstance(ubufs[1], MultiBuffer) and isinstance(ubufs[0], Buffer): # src is multiple buffers, none selected
# COPY ANY -> ONE. Currently we just select the first
schedule.append(ScheduleItem(ast, (ubufs[0], ubufs[1].bufs[0]), k.arg.metadata))
elif isinstance(ubufs[0], MultiBuffer) and isinstance(ubufs[1], Buffer):
# COPY ONE -> ALL (BROADCAST)
for b in ubufs[0].bufs: schedule.append(ScheduleItem(ast, (b, ubufs[1]), k.arg.metadata))
else:
assert isinstance(ubufs[1], Buffer), "src can't be MultiBuffer"
if isinstance(ubufs[0], MultiBuffer):
# COPY ONE -> ALL (BROADCAST)
for b in ubufs[0].bufs: schedule.append(ScheduleItem(ast, (b, ubufs[1]), k.arg.metadata))
else: schedule.append(ScheduleItem(ast, (ubufs[0], ubufs[1]), k.arg.metadata)) # COPY ONE -> ONE
raise RuntimeError("unsupported copy type")
else:
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
# ALL -> ALL
dnums = [x for x in ast.variables() if x.arg[0] == '_device_num']
for i,bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0]:i} if len(dnums) else {}))
else:
# ONE -> ONE
schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata))
for x in children[k]:
in_degree[x] -= 1