multi is O(1) (#10183)

* multi is O(1)

* allreduce

* no new uops needed

* junk

* something

* simple

* that's really what i want

* closer

* inject _device_num

* pretty print

* cleanups

* this

* early dnum

* ops allreduce is good

* ish

* device is the tuple and this is fine

* simpler

* progress

* copy_multi

* work

* more tests

* more tests pass

* work

* no None axis

* tests

* no none multi

* type fixes

* pre commit passes

* lil

* remove this

* mlperf dataloader on mac

* that test was wrong

* unbind

* support DEBUG=2

* realize

* only unbind bound vars

* don't include fixedvars

* graph test

* one test

* fixedvars in hcq

* new ring reduce

* ring reduce

* simpler ring

* mselect

* mselect doesn't work

* Revert "mselect doesn't work"

This reverts commit c78b77bd7d.

* Revert "mselect"

This reverts commit bb2e430ac3.

* simpler

* fixups

* no optional

* fix jit

* move things around

* cleanup multi

* simpler multi

* simpler reshape
Author: George Hotz
Date: 2025-05-16 23:14:23 -07:00
Committed by: GitHub
Parent: e1a40e8040
Commit: e13f2a3092
9 changed files with 157 additions and 122 deletions


@@ -5,7 +5,6 @@ from tinygrad.ops import Ops, UOp
 from tinygrad.helpers import CI, getenv, prod, Context, OSX
 from tinygrad.nn.state import get_parameters, get_state_dict
 from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner, run_schedule
-from tinygrad.engine.multi import all_reduce
 import numpy as np
 from hypothesis import given, strategies as strat, settings
 from tinygrad.device import is_dtype_supported
@@ -31,7 +30,7 @@ N = 128
 def _test_allreduce(t:Tensor):
   aa = (t[0:64] + t[64:128] + t[128:192] + t[192:256]).repeat([4,1]).realize()
   ts = t.shard(devices_4, 0).realize()
-  b = Tensor(UOp.multi(*all_reduce(Ops.ADD, ts.lazydata.src), axis=0))
+  b = Tensor(UOp.allreduce(ts.lazydata, Ops.ADD, ts.device))
   b.realize()
   return aa, b
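
For orientation, here is a minimal sketch of the call shape this hunk moves to: instead of wrapping the per-device outputs of all_reduce back into UOp.multi, the test builds a single UOp.allreduce over the sharded tensor's lazydata and its device tuple, which appears to be what the "multi is O(1)" title refers to. The device names, the tensor shape, and the devices_4 placeholder below are illustrative assumptions, not taken from the diff:

# Sketch only (assumed device names and shape), mirroring _test_allreduce above.
from tinygrad import Tensor
from tinygrad.ops import Ops, UOp

devices_4 = tuple(f"CPU:{i}" for i in range(4))   # hypothetical device tuple

t = Tensor.rand(256, 8)
ts = t.shard(devices_4, 0).realize()              # axis 0 split into four 64-row shards
# one allreduce UOp over the whole multi tensor: ADD the shards across devices,
# leaving the result on the same device tuple (ts.device)
b = Tensor(UOp.allreduce(ts.lazydata, Ops.ADD, ts.device))
b.realize()

The reduced value ends up replicated on every device in the tuple, which is why the single-device reference aa in _test_allreduce is built with .repeat([4,1]).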
@@ -84,7 +83,7 @@ class TestMultiTensor(unittest.TestCase):
     for si, ei in lower_schedule(sched):
       if isinstance(ei.prg, CompiledRunner): names.append(ei.prg.p.name)
       ei.run()
-    self.assertEqual(len(set(names)), 2), "function was relinearized"
+    self.assertEqual(len(set(names)), 1), "function was relinearized"

   @unittest.skip("this doesn't fold because shard_ calls contiguous on all lbs")
   def test_sharded_memory(self):
@@ -226,17 +225,16 @@ class TestMultiTensor(unittest.TestCase):
     out = f(tt)
     assert out.item() == 1+2+3+4

-  @unittest.skip("slow")
   def test_fuzz_allreduce(self):
     random.seed(41)
-    for it in range(100):
+    for it in range(2):
       for n in range(2, 4+1):
         shape = tuple([(n if i == 0 else 1) * random.randint(1, 10) for i in range(random.randint(1, 4))])
         t = Tensor.rand(shape).shard_(tuple([d0, d1, d2, d3][:n]), 0)
         with Context(RING=0):
-          a = Tensor(UOp.multi(*all_reduce(Ops.ADD, t.lazydata.src), axis=0))
+          a = Tensor(UOp.allreduce(t.lazydata, Ops.ADD, t.device))
         with Context(RING=2):
-          b = Tensor(UOp.multi(*all_reduce(Ops.ADD, t.lazydata.src), axis=0))
+          b = Tensor(UOp.allreduce(t.lazydata, Ops.ADD, t.device))
         diff = a - b
         mean_err = diff.reshape((prod(diff.shape),)).abs().mean().numpy()
         max_err = diff.reshape((prod(diff.shape),)).abs().max().numpy()
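
The RING=0 and RING=2 contexts in the fuzz test above appear to pin the allreduce to the non-ring and ring code paths respectively (the "ring reduce" / "simpler ring" commits in the message). As a reference for the communication pattern involved, here is a generic ring-allreduce sketch in plain NumPy; it is not tinygrad's implementation, and the function name and list-of-arrays layout are invented for illustration:

# Generic ring allreduce sketch (NOT tinygrad code): n "devices" each start with
# a full array; split it into n chunks, reduce-scatter, then all-gather.
import numpy as np

def ring_allreduce(tensors: list[np.ndarray]) -> list[np.ndarray]:
  n = len(tensors)
  chunks = [np.array_split(t, n) for t in tensors]   # per-device list of n chunks
  # reduce-scatter: after n-1 steps, device d holds the full sum of chunk (d+1) % n
  for step in range(n - 1):
    for src in range(n):
      dst, c = (src + 1) % n, (src - step) % n
      chunks[dst][c] = chunks[dst][c] + chunks[src][c]   # "send" chunk c to the next device and accumulate
  # all-gather: circulate the finished chunks so every device ends up with all of them
  for step in range(n - 1):
    for src in range(n):
      dst, c = (src + 1) % n, (src + 1 - step) % n
      chunks[dst][c] = chunks[src][c].copy()
  return [np.concatenate(ch) for ch in chunks]

# tiny check: every "device" should end up holding the elementwise sum
outs = ring_allreduce([np.full(8, i, dtype=np.float32) for i in range(4)])
assert all(np.allclose(o, 0 + 1 + 2 + 3) for o in outs)

Each of the 2*(n-1) steps moves only 1/n of the data per device, which is the usual argument for preferring a ring over a naive all-to-one reduction as the device count grows.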