new test_multitensor tests (#10667)
* new test_multitensor tests
* cleanup scheduler
@@ -1173,13 +1173,37 @@ class TestMultiAssign(unittest.TestCase):
     out[:, 2:3].assign(ones).realize()
     self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]])

+  def test_multi_assign_var_offset(self):
+    out = Tensor.zeros(4,4).contiguous().realize().shard(self.device, 0).realize()
+    ones = Tensor.ones(4,1).shard(self.device, 0).contiguous().realize()
+    vi = Variable("i", 0, 3).bind(2)
+    out[:, vi:vi+1].assign(ones).realize()
+    self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]])
+
+  def test_multi_assign_var_offset_jit_none(self): self.test_multi_assign_var_offset_jit(None)
+  def test_multi_assign_var_offset_jit(self, shard_axis=0):
+    out = Tensor.zeros(4,6).contiguous().realize().shard(self.device, shard_axis).realize()
+    ones = Tensor.ones(4,1).shard(self.device, shard_axis).contiguous().realize()
+
+    @TinyJit
+    def f(out:Tensor, vi):
+      out[:, vi:vi+1].assign(ones).realize()
+      ones.assign(ones+1).realize()
+
+    vi = Variable("i", 0, 5)
+    for i in range(1,5):
+      GlobalCounters.reset()
+      f(out, vi.bind(i))
+    self.assertListEqual(out.tolist(), [[0,1,2,3,4,0]]*4)

 @unittest.skipIf(not_support_multi_device(), "need multi")
 class TestMultiTransformer(unittest.TestCase):
   def test_transformer(self):
     device = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))

     from extra.models.llama import Transformer
-    args = {"dim": 64, "n_heads": 1, "n_kv_heads": 1, "n_layers": 2, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 1024, "hidden_dim": 64}
+    args = {"dim": 64, "n_heads": 1, "n_kv_heads": 1, "n_layers": 2, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 1024,
+            "hidden_dim": 64, "max_context": 12}
     real_model = Transformer(**args)
     shard_model = Transformer(**args)
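The new tests combine three pieces: a sharded tensor, a symbolic slice offset, and the JIT. A minimal single-device sketch of the Variable/TinyJit pattern they exercise, using only the public API that appears above (sharding dropped so it runs on any backend):

from tinygrad import Tensor, TinyJit, Variable

out = Tensor.zeros(4, 6).contiguous().realize()
ones = Tensor.ones(4, 1).contiguous().realize()

@TinyJit
def f(out: Tensor, vi):
  # the column offset is symbolic, so one compiled kernel serves every binding
  out[:, vi:vi+1].assign(ones).realize()
  ones.assign(ones + 1).realize()

vi = Variable("i", 0, 5)      # symbolic int with range [0, 5]
for i in range(1, 5):
  f(out, vi.bind(i))          # re-run the same jitted step at a new offset
print(out.tolist())           # [[0, 1, 2, 3, 4, 0]] * 4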
@@ -1199,10 +1223,18 @@ class TestMultiTransformer(unittest.TestCase):

     last_tok = 0
     for i in range(10):
-      real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), i)
-      shard_tok = shard_model(Tensor([[last_tok]], device=device), i)
-      last_tok = real_tok.item()
-      self.assertEqual(last_tok, shard_tok.item(), f"issue at token {i}")
+      real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), i).item()
+      shard_tok = shard_model(Tensor([[last_tok]], device=device), i).item()
+
+      # test kv cache
+      kv1 = real_model.layers[0].attention.cache_kv.numpy()
+      kv2 = shard_model.layers[0].attention.cache_kv.numpy()
+      #print(np.concatenate([kv1[:, :, :, :, 0:1], kv2[:, :, :, :, 0:1]], axis=4))
+      np.testing.assert_allclose(kv1, kv2, atol=1e-5, rtol=1e-5, err_msg=f"issue at token {i}")
+
+      # test token
+      self.assertEqual(real_tok, shard_tok, f"issue at token {i}")
+      last_tok = real_tok

   @unittest.skip("super slow")
   def test_llama1b_full(self):
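The sharded model's weights live across the device tuple (the sharding itself happens in setup code outside this hunk). A minimal sketch of the Tensor.shard primitive these tests are built on, assuming any backend that exposes two devices:

from tinygrad import Tensor, Device

devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
t = Tensor.arange(8).reshape(4, 2).contiguous().realize()
s = t.shard(devices, axis=0)   # rows 0-1 live on device 0, rows 2-3 on device 1
print(s.device)                # the device tuple
print((s + 1).tolist())        # elementwise ops run per shard; tolist gathers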
@@ -50,14 +50,6 @@ def copy_reorder_view(copy:UOp, view:UOp, base:UOp):
   if prod(view.shape) < prod(base.shape): return view.contiguous().copy_to_device(copy.device)
   return base.copy_to_device(copy.device).view(view.arg)

-def mselect_reorder_view(ms:UOp, view:UOp, base:UOp):
-  st = unwrap(view.st)
-  # replace dnum in ShapeTracker with literal const for this mselect
-  if (dnums:=[x for x in st.vars() if x.arg[0] == '_device_num']):
-    assert len(dnums) == 1, f"view must have exactly 0 or 1 dnum, got {dnums}"
-    st = st.substitute({dnums[0]:dnums[0].const_like(ms.arg)})
-  return base.mselect(ms.arg).view(st)
-
 ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK}

 sym = symbolic_simple+PatternMatcher([
@@ -79,8 +71,6 @@ sym = symbolic_simple+PatternMatcher([
   (UPat(Ops.COPY, name="c", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda c,x: x if c.device == x.device else None),
   # store a shrink before COPY, otherwise view after the COPY
   (UPat(Ops.COPY, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"), UPat(Ops.DEVICE)), name="copy"), copy_reorder_view),
-  # MSELECT must select a base, if there are views apply them after selecting the base
-  (UPat(Ops.MSELECT, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"),), name="ms"), mselect_reorder_view),
   # remove cast to image when it's already a contiguous image
   (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"),)),)),
    lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
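Each rule in these PatternMatchers pairs a UPat template with a rewrite function whose keyword arguments are the names bound in the pattern (base, view, ms, ...). A toy standalone rule in the same style, rewriting x*1 to x; a sketch assuming tinygrad's internal graph_rewrite API, which is not a stable public interface:

from tinygrad import dtypes
from tinygrad.uop.ops import Ops, UOp, UPat, PatternMatcher, graph_rewrite

# each rule is (pattern, rewrite_fn); named sub-patterns become keyword args
mul_one = PatternMatcher([
  (UPat(Ops.MUL, src=(UPat.var("x"), UPat(Ops.CONST, arg=1))), lambda x: x),
])

x = UOp.variable("i", 0, 10)
expr = x * UOp.const(dtypes.int, 1)
print(graph_rewrite(expr, mul_one))  # the MUL by 1 folds away, leaving the variable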
@@ -570,6 +560,7 @@ def get_kernelize_map(big_sink:UOp) -> dict[UOp, UOp]:
     if u.op is not Ops.ASSIGN: continue
     kernel_assign[u.buf_uop] = u
     for s in u.src[1].src:
+      # TODO: this is probably broken for MSELECT/MSTACK
       if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
       if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort()):
         raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
@@ -1,6 +1,6 @@
 from typing import cast
 import functools, itertools, operator
-from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv
+from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv, unwrap
 from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp

 # *** allreduce implementation ***
@@ -53,7 +53,21 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
   pads = [((s,numel-e),) for s,e in chunks]
   return functools.reduce(operator.add, [c.pad(pad) for pad,c in zip(pads, copied_chunks)]).reshape(shape)

-replace_allreduce = PatternMatcher([(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),])
+# ***** multi rewrite MSELECT/MSTACK *****
+
+def mselect_reorder_view(ms:UOp, view:UOp, base:UOp):
+  st = unwrap(view.st)
+  # replace dnum in ShapeTracker with literal const for this mselect
+  if (dnums:=[x for x in st.vars() if x.arg[0] == '_device_num']):
+    assert len(dnums) == 1, f"view must have exactly 0 or 1 dnum, got {dnums}"
+    st = st.substitute({dnums[0]:dnums[0].const_like(ms.arg)})
+  return base.mselect(ms.arg).view(st)
+
+replace_allreduce = PatternMatcher([
+  (UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
+  # MSELECT must select a base, if there are views apply them after selecting the base
+  (UPat(Ops.MSELECT, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"),), name="ms"), mselect_reorder_view),
+])

 # ***** multi functions *****

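The two context lines at the top of this hunk are the tail of handle_allreduce: each per-device reduced chunk is padded back to the full flat size with zeros and the padded chunks are summed, which reassembles the result without a concat. A numpy sketch of just that reassembly step, with stand-in data:

import numpy as np

numel = 8
chunks = [(0, 4), (4, 8)]                     # (start, end) per device, as in `chunks`
reduced = [np.full(4, 3.0), np.full(4, 7.0)]  # per-chunk allreduce results
# zero-pad each chunk to the full size, then add: equivalent to a concat
out = sum(np.pad(c, (s, numel - e)) for (s, e), c in zip(chunks, reduced))
print(out)  # [3. 3. 3. 3. 7. 7. 7. 7.]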
@@ -77,28 +77,23 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
       buffers[k.src[0]] = base.view(k.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
     ubufs = tuple(s.buf_uop.buffer for s in k.src)
     if any(isinstance(x, MultiBuffer) for x in ubufs):
-      if ast.op is Ops.COPY:
-        assert ast.arg is None, "copy arg is no longer supported"
-        if isinstance(ubufs[1], MultiBuffer): # src is multiple buffers, none selected
-          if isinstance(ubufs[0], MultiBuffer):
-            # COPY ALL -> ALL
-            assert len(ubufs[0].bufs) == len(ubufs[1].bufs), "all to all copy must have matching buffer length"
-            for b1,b2 in zip(ubufs[0].bufs, ubufs[1].bufs): schedule.append(ScheduleItem(ast, (b1, b2), k.arg.metadata))
-          else:
-            # COPY ANY -> ONE. Currently we just select the first
-            schedule.append(ScheduleItem(ast, (ubufs[0], ubufs[1].bufs[0]), k.arg.metadata))
+      if ast.op is Ops.COPY and (isinstance(ubufs[0], Buffer) or isinstance(ubufs[1], Buffer)):
+        if isinstance(ubufs[1], MultiBuffer) and isinstance(ubufs[0], Buffer): # src is multiple buffers, none selected
+          # COPY ANY -> ONE. Currently we just select the first
+          schedule.append(ScheduleItem(ast, (ubufs[0], ubufs[1].bufs[0]), k.arg.metadata))
+        elif isinstance(ubufs[0], MultiBuffer) and isinstance(ubufs[1], Buffer):
+          # COPY ONE -> ALL (BROADCAST)
+          for b in ubufs[0].bufs: schedule.append(ScheduleItem(ast, (b, ubufs[1]), k.arg.metadata))
         else:
-          assert isinstance(ubufs[1], Buffer), "src can't be MultiBuffer"
-          if isinstance(ubufs[0], MultiBuffer):
-            # COPY ONE -> ALL (BROADCAST)
-            for b in ubufs[0].bufs: schedule.append(ScheduleItem(ast, (b, ubufs[1]), k.arg.metadata))
-          else: schedule.append(ScheduleItem(ast, (ubufs[0], ubufs[1]), k.arg.metadata)) # COPY ONE -> ONE
+          raise RuntimeError("unsupported copy type")
       else:
         assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
         # ALL -> ALL
         dnums = [x for x in ast.variables() if x.arg[0] == '_device_num']
         for i,bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
           schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0]:i} if len(dnums) else {}))
     else:
       # ONE -> ONE
       schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata))
     for x in children[k]:
       in_degree[x] -= 1
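After the cleanup, the explicit copy path only covers cases where exactly one side is a plain Buffer; a COPY between two MultiBuffers falls through to the per-device kernel path below it. A standalone sketch of that case analysis, with hypothetical Buf/MultiBuf stand-ins in place of tinygrad's real Buffer/MultiBuffer:

from dataclasses import dataclass

@dataclass
class Buf: name: str        # stand-in for a single-device Buffer
@dataclass
class MultiBuf: bufs: list  # stand-in for a MultiBuffer (one Buf per device)

def plan_copy(dst, src) -> list[tuple]:
  # mirrors the (dst, src) pairings the scheduler emits, one per ScheduleItem
  if isinstance(src, MultiBuf) and isinstance(dst, Buf):
    return [(dst, src.bufs[0])]               # COPY ANY -> ONE: just take the first shard
  if isinstance(dst, MultiBuf) and isinstance(src, Buf):
    return [(b, src) for b in dst.bufs]       # COPY ONE -> ALL: broadcast to every shard
  raise RuntimeError("unsupported copy type")  # ALL -> ALL goes through the kernel path

print(plan_copy(MultiBuf([Buf("gpu:0"), Buf("gpu:1")]), Buf("host")))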