delete ReduceOps, only use REDUCE_AXIS (#7667)
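Per the title, this patch folds the standalone ReduceOps enum into the unified Ops enum, so a reduction is expressed only through REDUCE_AXIS with the combining op as its argument. The mapping used throughout the diff is ReduceOps.SUM -> Ops.ADD, ReduceOps.PROD -> Ops.MUL, ReduceOps.REDUCE_MAX -> Ops.MAX. A minimal sketch of that mapping (the dict and its string keys are illustrative; only the Ops members come from tinygrad):

from tinygrad.ops import Ops

# Old ReduceOps name -> unified Ops member, as exercised by the tests below.
REDUCE_TO_OPS = {
  "SUM": Ops.ADD,         # ReduceOps.SUM        -> Ops.ADD
  "PROD": Ops.MUL,        # ReduceOps.PROD       -> Ops.MUL
  "REDUCE_MAX": Ops.MAX,  # ReduceOps.REDUCE_MAX -> Ops.MAX
}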
@@ -1,7 +1,7 @@
 import unittest, functools, random
 from typing import List
 from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes
-from tinygrad.ops import MetaOps, ReduceOps, BinaryOps, Ops
+from tinygrad.ops import MetaOps, BinaryOps, Ops
 from tinygrad.helpers import CI, getenv, prod, Context
 from tinygrad.nn.state import get_parameters, get_state_dict
 from tinygrad.engine.schedule import create_schedule
@@ -31,7 +31,7 @@ N = 128
 def _test_allreduce(t:Tensor):
   aa = (t[0:64] + t[64:128] + t[128:192] + t[192:256]).repeat([4,1]).realize()
   ts = t.shard(devices_4, 0).realize()
-  b = Tensor(MultiLazyBuffer(all_reduce(ReduceOps.SUM, ts.lazydata.lbs), 0))
+  b = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, ts.lazydata.lbs), 0))
   b.realize()
   return aa, b
 
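all_reduce now takes the combining Op directly: summing the four shards with Ops.ADD reproduces the manual slice-and-add in aa. A hedged usage sketch in the spirit of _test_allreduce, assuming all_reduce and MultiLazyBuffer live in tinygrad.multi (that import is not shown in this hunk) and that four copies of the default device are usable as shards:

from tinygrad import Tensor, Device
from tinygrad.multi import MultiLazyBuffer, all_reduce
from tinygrad.ops import Ops

devices_4 = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))  # assumed shard layout
t = Tensor.rand(256, 8)
# Reference: sum the four 64-row blocks by hand, then tile back to 256 rows.
expected = (t[0:64] + t[64:128] + t[128:192] + t[192:256]).repeat([4, 1]).realize()
# Same reduction via all_reduce over the per-device lazy buffers.
ts = t.shard(devices_4, 0).realize()
summed = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, ts.lazydata.lbs), 0)).realize()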
@@ -145,14 +145,14 @@ class TestMultiTensor(unittest.TestCase):
     np.testing.assert_allclose(O.numpy(), X.numpy()[0:2]*W.numpy()[0:2] < 2)
 
   @given(strat.sampled_from((4, 5)), strat.sampled_from((devices_2, devices_3)),
-         strat.sampled_from((ReduceOps.SUM, ReduceOps.PROD, ReduceOps.REDUCE_MAX)),
+         strat.sampled_from((Ops.ADD, Ops.MUL, Ops.MAX)),
          strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)), strat.sampled_from((1, 0, -1)))
   def test_simple_reduce(self, N, devices, rop, shard_axis, reduce_axis, sign):
     X = Tensor.rand(N*N).reshape(N, N).mul(sign)
     n = X.numpy()
     X.shard_(devices, shard_axis)
-    f = {ReduceOps.SUM: lambda x: x.sum(reduce_axis), ReduceOps.PROD: lambda x: x.prod(reduce_axis),
-         ReduceOps.REDUCE_MAX: lambda x: x.max(reduce_axis)}[rop]
+    f = {Ops.ADD: lambda x: x.sum(reduce_axis), Ops.MUL: lambda x: x.prod(reduce_axis),
+         Ops.MAX: lambda x: x.max(reduce_axis)}[rop]
     fX = f(X)
     fn = f(n)
     np.testing.assert_allclose(fX.numpy(), fn, rtol=1e-6, atol=1e-6)
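The dict in test_simple_reduce pairs each unified Op with the tensor-level reduction it denotes, and the same lambda works on both a Tensor and a numpy array. A standalone sketch of that equivalence check on a single device (shape and names are illustrative):

import numpy as np
from tinygrad import Tensor
from tinygrad.ops import Ops

reducers = {Ops.ADD: lambda x: x.sum(0), Ops.MUL: lambda x: x.prod(0), Ops.MAX: lambda x: x.max(0)}
X = Tensor.rand(8, 8)
for op, f in reducers.items():
  # Each Op's reduction on the Tensor should match the numpy reduction.
  np.testing.assert_allclose(f(X).numpy(), f(X.numpy()), rtol=1e-6, atol=1e-6)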
@@ -197,9 +197,9 @@ class TestMultiTensor(unittest.TestCase):
       shape = tuple([(n if i == 0 else 1) * random.randint(1, 10) for i in range(random.randint(1, 4))])
       t = Tensor.rand(shape).shard_(tuple([d0, d1, d2, d3][:n]), 0)
       with Context(RING=0):
-        a = Tensor(MultiLazyBuffer(all_reduce(ReduceOps.SUM, t.lazydata.lbs), 0))
+        a = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, t.lazydata.lbs), 0))
       with Context(RING=2):
-        b = Tensor(MultiLazyBuffer(all_reduce(ReduceOps.SUM, t.lazydata.lbs), 0))
+        b = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, t.lazydata.lbs), 0))
       diff = a - b
       mean_err = diff.reshape((prod(diff.shape),)).abs().mean().numpy()
       max_err = diff.reshape((prod(diff.shape),)).abs().max().numpy()
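As this test reads, Context(RING=...) selects the all_reduce implementation: RING=0 and RING=2 force two different code paths (naive vs ring, as I read the test's intent) whose results should agree up to floating-point error, which mean_err and max_err quantify. A self-contained sketch of that error-summary pattern (the two rand tensors stand in for the two all_reduce results):

from tinygrad import Tensor
from tinygrad.helpers import prod

a, b = Tensor.rand(4, 4), Tensor.rand(4, 4)  # stand-ins for the two all_reduce outputs
diff = a - b
# Flatten the discrepancy tensor, take absolute values, then summarize.
flat = diff.reshape((prod(diff.shape),)).abs()
mean_err, max_err = flat.mean().numpy(), flat.max().numpy()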