Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-06 21:53:53 -05:00)

.github/workflows/test.yml
@@ -144,7 +144,7 @@ jobs:
         sudo apt update || true
         sudo apt install -y --no-install-recommends ninja-build
     - name: Test beautiful_mnist in torch with TINY_BACKEND
-      run: SPLIT_REDUCEOP=0 FUSE_ARANGE=1 CPU=1 CPU_LLVM=1 TARGET_EVAL_ACC_PCT=96.0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py
+      run: CPU=1 CPU_LLVM=1 TARGET_EVAL_ACC_PCT=96.0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py
     - name: Test some torch tests (expect failure)
       run: python3 -m pytest extra/torch_backend/torch_tests.py -v --tb=no || true

@@ -533,7 +533,7 @@ jobs:
     - name: Test LLVM=1 DEVECTORIZE=0 for model
       run: CPU=1 CPU_LLVM=1 DEVECTORIZE=0 python3 test/models/test_efficientnet.py
     - name: Test CPU=1 DEVECTORIZE=0
-      run: CPU=1 CPU_LLVM=0 DEVECTORIZE=0 FUSE_ARANGE=0 python3 -m pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"
+      run: CPU=1 CPU_LLVM=0 DEVECTORIZE=0 python3 -m pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"

  testdsp:
    name: Linux (DSP)

@@ -10,7 +10,7 @@ GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 1))]

 # override tinygrad defaults
 dtypes.default_float = dtypes.half
-Context(FUSE_ARANGE=1, FUSE_OPTIM=1).__enter__()
+Context(FUSE_OPTIM=1).__enter__()

 # from https://github.com/tysam-code/hlb-CIFAR10/blob/main/main.py
 batchsize = getenv("BS", 1024)

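For context on the pattern in this hunk: Context(...).__enter__() applies an override for the rest of the process instead of scoping it to a with-block. A minimal sketch of that behavior, assuming only tinygrad.helpers.Context and the DEBUG ContextVar (not anything specific to this file):

from tinygrad.helpers import Context, DEBUG

# calling __enter__ without a matching __exit__ leaves the override active
# for the lifetime of the process, which is what the module-level line above does
Context(DEBUG=2).__enter__()
assert DEBUG.value == 2
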
@@ -145,7 +145,6 @@ hyp = {
   },
 }

-@Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1))
 def train_cifar():

   def set_seed(seed):

@@ -1309,7 +1309,7 @@ def train_llama3():
   EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 16)
   EVAL_TARGET = config["EVAL_TARGET"] = getenv("EVAL_TARGET", 5.6)

-  # LR=1e-4 TRAIN_ON_VAL=1 DEFAULT_FLOAT=bfloat16 FUSE_ARANGE=1 JITBEAM=2 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B WARMUP_STEPS=36 DECAY_STEPS=360 SEQLEN=512 PYTHONPATH=. AMD=1 AMD_LLVM=0 MODEL=llama3 python3 examples/mlperf/model_train.py
+  # LR=1e-4 TRAIN_ON_VAL=1 DEFAULT_FLOAT=bfloat16 JITBEAM=2 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B WARMUP_STEPS=36 DECAY_STEPS=360 SEQLEN=512 PYTHONPATH=. AMD=1 AMD_LLVM=0 MODEL=llama3 python3 examples/mlperf/model_train.py
   # trains to 7

   opt_adamw_beta_1 = 0.9

@@ -7,7 +7,6 @@ bert_train_params = {
   "GPUS": 6,
   "BS": 96,
   "EVAL_BS": 96,
-  "FUSE_ARANGE": 1,
   "BASEDIR": "/raid/datasets/wiki",
 }

@@ -227,16 +227,15 @@ class TestTorchBackend(unittest.TestCase):
     np.testing.assert_equal(result.cpu().numpy(), [3., 3., 2.])

   def test_mnist_index(self):
-    with Context(FUSE_ARANGE=1, SPLIT_REDUCEOP=0):
-      GlobalCounters.reset()
-      from tinygrad.nn.datasets import mnist
-      X_train, Y_train, _, _ = mnist()
-      X_train = torch.tensor(X_train.float().numpy(), device=device)
-      Y_train = torch.tensor(Y_train.cast('int64').numpy(), device=device)
-      samples = torch.randint(0, X_train.shape[0], (32,))
-      X,Y = X_train[samples], Y_train[samples]
-      X.cpu(), Y.cpu()
-      self.assertLessEqual(GlobalCounters.global_ops, 10_000_000)
+    GlobalCounters.reset()
+    from tinygrad.nn.datasets import mnist
+    X_train, Y_train, _, _ = mnist()
+    X_train = torch.tensor(X_train.float().numpy(), device=device)
+    Y_train = torch.tensor(Y_train.cast('int64').numpy(), device=device)
+    samples = torch.randint(0, X_train.shape[0], (32,))
+    X,Y = X_train[samples], Y_train[samples]
+    X.cpu(), Y.cpu()
+    self.assertLessEqual(GlobalCounters.global_ops, 10_000_000)

   def _test_diagonal(self, *shape):
     a = torch.randn(*shape, dtype=torch.float32, device=device)

@@ -25,22 +25,6 @@ class TestArange(unittest.TestCase):
     t = Tensor.arange(2, dtype=dtypes.int)+Tensor([3])
     self.assertEqual(t.cat(t).tolist(), [3, 4, 3, 4])

-class TestRand(unittest.TestCase):
-  def test_fused_rand_less_ops(self, noopt=1):
-    GlobalCounters.reset()
-    with Context(FUSE_ARANGE=0, NOOPT=noopt):
-      out = Tensor.rand(16384)
-      out.realize()
-    unfused_ops = GlobalCounters.global_ops
-
-    GlobalCounters.reset()
-    with Context(FUSE_ARANGE=1, NOOPT=noopt):
-      out = Tensor.rand(16384)
-      out.realize()
-    print(f"fused {GlobalCounters.global_ops} unfused {unfused_ops}")
-    self.assertLessEqual(GlobalCounters.global_ops, unfused_ops*2)
-  def test_fused_rand_less_ops_opt(self): self.test_fused_rand_less_ops(0)
-
 DSET, DDIM = 2048, 32

 class TestIndexing(unittest.TestCase):

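The deleted TestRand compared fused and unfused op counts with the same measurement pattern the remaining tests keep using. A minimal sketch of that pattern (the 16384-element size mirrors the deleted test; only the public Tensor/GlobalCounters API is assumed):

from tinygrad import Tensor
from tinygrad.helpers import GlobalCounters

GlobalCounters.reset()                      # zero the op/kernel counters
Tensor.rand(16384).realize()                # run the kernels for a 16384-element rand
print("ops:", GlobalCounters.global_ops, "kernels:", GlobalCounters.kernel_count)
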
@@ -48,7 +32,7 @@ class TestIndexing(unittest.TestCase):
     needle = Tensor.zeros(16384, dtype=dtypes.int).contiguous()
     needle[1337] = 1
     needle.realize()
-    with Context(NOOPT=1, FUSE_ARANGE=1):
+    with Context(NOOPT=1):
       GlobalCounters.reset()
       out = ((Tensor.arange(1,16385)-1)*needle).sum()
       sched = out.schedule()

@@ -61,7 +45,7 @@ class TestIndexing(unittest.TestCase):
     idxs = Tensor([0,3,5,6]).realize()
     real_index = dataset.numpy()[idxs.numpy()]
     print("*** indexing ***")
-    with Context(NOOPT=1, FUSE_ARANGE=1):
+    with Context(NOOPT=1):
       GlobalCounters.reset()
       rng = Tensor.ones(4, DDIM, DSET, dtype=dtypes.int)._cumalu(axis=-1, op=Ops.ADD, _include_initial=True).reshape(4, DDIM, DSET, 1)
       idxs = idxs.reshape(4,1,1,1).expand(4, DDIM, DSET, 1)

@@ -77,7 +61,7 @@ class TestIndexing(unittest.TestCase):
   def test_index_variable(self):
     dataset = Tensor.rand(DSET, DDIM).realize()
     v = Variable("v", 0, DDIM-1)
-    with Context(NOOPT=1, FUSE_ARANGE=1, SPLIT_REDUCEOP=0):
+    with Context(NOOPT=1):
       GlobalCounters.reset()
       vb = Tensor(v.bind(12))
       comp = dataset[vb].numpy()

@@ -106,7 +90,7 @@ class TestIndexing(unittest.TestCase):
     idxs = Tensor([0,3,5,6]).realize()
     real_index = dataset.numpy()[idxs.numpy()]
     print("*** indexing ***")
-    with Context(NOOPT=noopt, FUSE_ARANGE=1):
+    with Context(NOOPT=noopt):
       GlobalCounters.reset()
       X = dataset[idxs]
       assert X.shape == (4,DDIM)

@@ -121,7 +105,7 @@ class TestIndexing(unittest.TestCase):
   def test_index_fused_out_of_bounds(self):
     dataset = Tensor.rand(256, 256).realize()
     idxs = Tensor([-19238, -257, 256, 495, 10982377]).realize()
-    with Context(NOOPT=1, FUSE_ARANGE=1):
+    with Context(NOOPT=1):
       X = dataset[idxs]
       np.testing.assert_equal(X.numpy(), 0)

@@ -130,7 +114,7 @@ class TestIndexing(unittest.TestCase):
     if Device.DEFAULT == "WEBGPU": op_limit *= 15
     from tinygrad.nn.datasets import mnist
     X_train, Y_train, _, _ = mnist()
-    with Context(NOOPT=noopt, FUSE_ARANGE=1, SPLIT_REDUCEOP=split_reduceop):
+    with Context(NOOPT=noopt, SPLIT_REDUCEOP=split_reduceop):
       samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0]).realize()
       GlobalCounters.reset()
       x = X_train[samples].numpy()

@@ -150,7 +134,7 @@ class TestIndexing(unittest.TestCase):
     # TODO: why is a new realize needed here
     emb_w = emb.weight.realize().numpy()
     x = Tensor([1,2,3,4])
-    with Context(NOOPT=noopt, FUSE_ARANGE=1):
+    with Context(NOOPT=noopt):
       GlobalCounters.reset()
       z = emb(x).realize()
       self.assertLessEqual(GlobalCounters.global_ops, op_limit)

@@ -447,11 +447,11 @@ class TestNN(unittest.TestCase):

   # TODO: fused with opts uses more ops
   def test_embedding_one_kernel_fused(self):
-    with Context(FUSE_ARANGE=1, NOOPT=0):
+    with Context(NOOPT=0):
       self.test_embedding_one_kernel(ops=612_000, kcount=2)

   def test_embedding_one_kernel_fused_noopt(self):
-    with Context(FUSE_ARANGE=1, NOOPT=1):
+    with Context(NOOPT=1):
       self.test_embedding_one_kernel(ops=0, kcount=2)

   def test_embedding_shape(self):

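These tests drive Embedding through test_embedding_one_kernel with explicit op and kernel-count budgets. A small sketch of the underlying lookup being budgeted (the sizes here are arbitrary illustration, not the test's):

from tinygrad import Tensor
from tinygrad.nn import Embedding
from tinygrad.helpers import GlobalCounters

emb = Embedding(100, 64)                    # vocab of 100, 64-dim vectors
idx = Tensor([1, 2, 3, 4])                  # integer indices to gather
GlobalCounters.reset()
out = emb(idx).realize()                    # gathers 4 rows of the weight matrix
print(out.shape, GlobalCounters.global_ops, GlobalCounters.kernel_count)
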
@@ -465,10 +465,9 @@ class TestNN(unittest.TestCase):

   def test_embedding_regression(self):
     # used to fail bounds check
-    with Context(FUSE_ARANGE=1):
-      embedding = Embedding(100, 1024)
-      input_ids = Tensor.empty(16, 16, dtype=dtypes.int)
-      embedding(input_ids).realize()
+    embedding = Embedding(100, 1024)
+    input_ids = Tensor.empty(16, 16, dtype=dtypes.int)
+    embedding(input_ids).realize()

   def test_load_state_dict(self):
     layer = Conv2d(3, 5, kernel_size=3)

@@ -83,33 +83,30 @@ class TestSchedule(unittest.TestCase):
     np.testing.assert_allclose(t.numpy(), torch_out)

   def test_arange_avgpool2d_fused_noopt(self):
-    with Context(FUSE_ARANGE=1, NOOPT=1): self.test_arange_avgpool2d(kcount=1)
+    with Context(NOOPT=1): self.test_arange_avgpool2d(kcount=1)

   # linearizer error
   @unittest.skip("recursion error no longer raised")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "needs supports_float4 to fail")
   def test_arange_avgpool2d_fused(self):
     with self.assertRaises(RecursionError):
-      with Context(FUSE_ARANGE=1, NOOPT=0): self.test_arange_avgpool2d(kcount=1)
+      with Context(NOOPT=0): self.test_arange_avgpool2d(kcount=1)

   # when we're fusing a reduce, all ReduceOps must have the same N in the dimensions
   # all permutes, reshapes, expands and shrinks push through the reduce
   def test_arange_sum(self):
     a = Tensor.arange(6).reshape(3, 2).sum(axis=1)
-    with Context(FUSE_ARANGE=1):
-      run_schedule(check_schedule(a, 1))
+    run_schedule(check_schedule(a, 1))
     self.assertListEqual(a.tolist(), [1, 5, 9])

   def test_arange_sum_alt(self):
     a = (Tensor.arange(5).reshape(1,5).expand(6,5)*Tensor(2)).reshape(1,6,5).sum(axis=2)
-    with Context(FUSE_ARANGE=1):
-      run_schedule(check_schedule(a, 1))
+    run_schedule(check_schedule(a, 1))
     np.testing.assert_equal(a.numpy(), 20)

   def test_permute_arange(self):
     a = Tensor.arange(6).reshape(6, 1, 1).permute(2, 0, 1).sum(axis=1)
-    with Context(FUSE_ARANGE=1):
-      run_schedule(check_schedule(a, 1))
+    run_schedule(check_schedule(a, 1))
     self.assertListEqual(a.tolist(), [[15]])

   @unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch")

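With arange fusion now unconditional, the single-kernel behavior these check_schedule calls assert can be reproduced from the public API alone. A hedged sketch (fresh tensors per query so schedule() and tolist() stay independent; the kernel count may vary by device):

from tinygrad import Tensor

a = Tensor.arange(6).reshape(3, 2).sum(axis=1)
print("kernels:", len(a.schedule()))        # expect 1: the arange folds into the reduce
b = Tensor.arange(6).reshape(3, 2).sum(axis=1)
print(b.tolist())                           # [1, 5, 9], matching the test's assertion
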
@@ -137,8 +134,7 @@ class TestSchedule(unittest.TestCase):
   def test_indexing_scalars_simple(self):
     X = Tensor.randn(2, 2).realize()
     xt = X[Tensor(1)][Tensor(0)]
-    with Context(FUSE_ARANGE=1):
-      run_schedule(check_schedule(xt, 2))
+    run_schedule(check_schedule(xt, 2))
     np.testing.assert_equal(xt.numpy(), X.numpy()[1][0])

   @unittest.skipIf(CI and Device.DEFAULT == "NV", "crashes on NV CI")

@@ -158,8 +154,7 @@ class TestSchedule(unittest.TestCase):
     assume(a<x and b<y)
     X = Tensor.randn(x, y).realize()
     xt = X[Tensor(a)][Tensor(b)]
-    with Context(FUSE_ARANGE=1):
-      run_schedule(check_schedule(xt, 2))
+    run_schedule(check_schedule(xt, 2))
     np.testing.assert_equal(xt.numpy(), X.numpy()[a][b])

   def test_push_pads_elementwise(self):

@@ -1574,8 +1569,7 @@ class TestSchedule(unittest.TestCase):
     x = Tensor.empty(3,3,3,3, requires_grad=True)
     y = x.pad((-1,2,2,-1), mode="replicate")
     dx = y.sum().gradient(x)[0]
-    with Context(FUSE_ARANGE=1):
-      sched = check_schedule(dx, 3)
+    sched = check_schedule(dx, 3)
     run_schedule(sched)
     np.testing.assert_allclose(dx.numpy(), [[[[0.,3.,9.],[0,1.,3.],[0.,0.,0.]]]*3]*3)

@@ -1876,8 +1870,7 @@ class TestSchedule(unittest.TestCase):
     from extra.models.llama import precompute_freqs_cis
     args = {"dim":32 if CI else 128, "end":2048 if CI else 8192, "theta":10000}
     fused = precompute_freqs_cis(**args)
-    with Context(FUSE_ARANGE=1):
-      run_schedule(check_schedule(fused, 3))
+    run_schedule(check_schedule(fused, 3))
     if getenv("CHECK", 1):
       ref = precompute_freqs_cis(**args)
       run_schedule(check_schedule(ref, 3))

@@ -1961,15 +1954,6 @@ class TestSchedule(unittest.TestCase):
     np.testing.assert_allclose(out0.numpy(), r_ref+2, rtol=2e-7)
     np.testing.assert_allclose(out1.numpy(), r_ref+3, rtol=2e-7)

-  @unittest.skip("multi output isn't supported")
-  def test_multiview_arange_children(self):
-    X = Tensor.randn(2,3,4,4).numpy()
-    with Context(FUSE_ARANGE=1):
-      compare = Tensor(X).interpolate(size=(2, 2), mode="linear").numpy()
-    with Context(FUSE_ARANGE=0, TRACK_MATCH_STATS=0):
-      ref = Tensor(X).interpolate(size=(2, 2), mode="linear").numpy()
-    np.testing.assert_allclose(ref, compare, atol=1e-5, rtol=1e-6)
-
   def test_recursive_swizzle(self):
     a = Tensor([1,2,3,4]).realize()
     for _ in range(24): a = a + a

@@ -40,7 +40,7 @@ class TestStunning(unittest.TestCase):
     Y_train = Y_train.one_hot(10)
     X_samp, Y_samp = X_train[samples], Y_train[samples]
     vi = Variable('i', 0, samples.shape[0]-1)
-    with Context(FUSE_ARANGE=1, SPLIT_REDUCEOP=0):
+    with Context(SPLIT_REDUCEOP=0):
       with Tensor.train():
         losses = []
         for i in range(samples.shape[0]):

@@ -133,7 +133,7 @@ JIT, JIT_BATCH_SIZE = ContextVar("JIT", 2 if OSX and ARCH_X86 else 1), ContextVa
 WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
 USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0)
 TRANSCENDENTAL, NOLOCALS = ContextVar("TRANSCENDENTAL", 1), ContextVar("NOLOCALS", 0)
-FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 1), ContextVar("FUSE_CONV_BW", 0)
+FUSE_CONV_BW = ContextVar("FUSE_CONV_BW", 0)
 SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
 PICKLE_BUFFERS, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("LRU", 1)
 CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)

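For reference, each of these ContextVars reads its default from an environment variable of the same name, and Context temporarily overrides the value inside a with-block. A minimal sketch of that behavior using DEBUG, which this hunk leaves untouched:

from tinygrad.helpers import Context, DEBUG

print(DEBUG.value)        # default (0) unless the DEBUG env var was set at import time
with Context(DEBUG=2):
  print(DEBUG.value)      # 2 inside the block
print(DEBUG.value)        # restored afterwards
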
@@ -2,7 +2,7 @@ from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewr
 from tinygrad.uop.ops import track_rewrites, _substitute, KernelInfo
 from tinygrad.uop.spec import type_verify, tensor_uop_spec
 from tinygrad.uop.symbolic import symbolic_simple
-from tinygrad.helpers import all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
+from tinygrad.helpers import all_int, all_same, prod, dedup, unwrap, getenv, pluralize, DEBUG, SPLIT_REDUCEOP
 from tinygrad.dtype import ImageDType
 from tinygrad.schedule.multi import multi_pm
 from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS

@@ -250,7 +250,7 @@ def do_fusion(x:UOp):

 def fuse_arange(root:UOp):
   # skip if root is arange
-  if not FUSE_ARANGE or root.src[0].base.op is Ops.CONST: return None
+  if root.src[0].base.op is Ops.CONST: return None
   # gather all local aranges (including any fused ones)
   local_arange: list[UOp] = []
   def gate_reduce(u):

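fuse_arange is the rewrite that lets gather-style indexing (dataset[idxs] builds an implicit arange comparison that reduces over the dataset axis) fuse into a single reduce kernel; with the flag gone it always applies. A hedged sketch of the user-visible effect, using only the public Tensor API and the DSET/DDIM sizes from the tests above:

from tinygrad import Tensor
from tinygrad.helpers import GlobalCounters

dataset = Tensor.rand(2048, 32).realize()
idxs = Tensor([0, 3, 5, 6]).realize()

GlobalCounters.reset()
X = dataset[idxs].realize()                 # gather 4 rows; the arange compare fuses into the reduce
print(X.shape, GlobalCounters.global_ops, GlobalCounters.kernel_count)
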