mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
clean up more RANGEIFY flag (#12556)
This commit is contained in:
@@ -3,7 +3,7 @@ import unittest
|
||||
import numpy as np
|
||||
from tinygrad import dtypes, Tensor, TinyJit, GlobalCounters, Variable
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import temp, RANGEIFY
|
||||
from tinygrad.helpers import temp
|
||||
|
||||
N = 200 # has to be bigger than the cache to fail
|
||||
|
||||
@@ -300,7 +300,6 @@ class TestAssign(unittest.TestCase):
|
||||
#assert ba1 == ba2 and ba1 != bb1
|
||||
np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
|
||||
|
||||
@unittest.skipUnless(RANGEIFY, "only correct in rangeify")
|
||||
def test_post_permuted_assignment_alt(self):
|
||||
a = Tensor.arange(N*N).reshape(N,N).contiguous().realize()
|
||||
b = Tensor.arange(N*N).reshape(N,N).contiguous().realize()
|
||||
|
||||
@@ -3,7 +3,6 @@ from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.dtype import DType, ConstType
|
||||
from tinygrad.uop.ops import Ops, UOp
|
||||
from tinygrad.codegen import full_rewrite_to_sink
|
||||
from tinygrad.helpers import RANGEIFY
|
||||
from tinygrad.device import is_dtype_supported
|
||||
import numpy as np
|
||||
from test.helpers import not_support_multi_device
|
||||
@@ -158,8 +157,7 @@ class TestMovedConstFolding(unittest.TestCase):
|
||||
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) + Tensor.zeros(6).shrink(((1, 5),)))
|
||||
|
||||
def test_add_padded_zero(self):
|
||||
# TODO: it's 1 now, this might be possible to fold
|
||||
_check_ast_count(0 if RANGEIFY else 1, Tensor([1.0, 2, 3, 4]) + Tensor.zeros(2).pad(((1, 1),)))
|
||||
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) + Tensor.zeros(2).pad(((1, 1),)))
|
||||
|
||||
def test_mul_shrunk_one(self):
|
||||
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) * Tensor.ones(6).shrink(((1, 5),)))
|
||||
@@ -168,16 +166,16 @@ class TestMovedConstFolding(unittest.TestCase):
|
||||
_check_ast_count(1, Tensor([1.0, 2, 3, 4]) * Tensor.ones(2).pad(((1, 1),)))
|
||||
|
||||
def test_cast_padded(self):
|
||||
# NOTE: RANGEIFY or not, it's always 1 kernel when calling .numpy, limitation of _check_ast_count
|
||||
# NOTE: it's always 1 kernel when calling .numpy, limitation of _check_ast_count
|
||||
if is_dtype_supported(dtypes.int16):
|
||||
_check_ast_count(1 if RANGEIFY else 0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
|
||||
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
|
||||
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0])
|
||||
if is_dtype_supported(dtypes.uint16):
|
||||
_check_ast_count(1 if RANGEIFY else 0, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
|
||||
_check_ast_count(1, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
|
||||
np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
|
||||
# folded
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_check_ast_count(1 if RANGEIFY else 0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
|
||||
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
|
||||
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0])
|
||||
|
||||
class TestReduceOpsConstFolding(unittest.TestCase):
|
||||
@@ -249,7 +247,7 @@ class TestReduceOpsConstFolding(unittest.TestCase):
|
||||
t = Tensor.ones(16, dtype=dt).reshape(4, 4)
|
||||
assert t.sum().dtype == t.contiguous().sum().dtype
|
||||
|
||||
@unittest.skipIf(not_support_multi_device() or RANGEIFY, "no multi, RANGEIFY doesn't support multi const folding")
|
||||
@unittest.skipIf(not_support_multi_device() or True, "no multi, RANGEIFY doesn't support multi const folding")
|
||||
class TestMultiConstFolding(unittest.TestCase):
|
||||
def test_multi_const_folding_literal(self):
|
||||
ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))
|
||||
|
||||
@@ -4,7 +4,7 @@ from tinygrad import Device, dtypes, Tensor, Context
|
||||
from tinygrad.device import LRUAllocator, is_dtype_supported
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.engine.realize import lower_schedule
|
||||
from tinygrad.helpers import prod, unwrap, RANGEIFY
|
||||
from tinygrad.helpers import prod, unwrap
|
||||
from test.helpers import REAL_DEV
|
||||
|
||||
IMAGE_SUPPORTED_DEVICES = ("QCOM", "CL")
|
||||
@@ -139,7 +139,7 @@ class TestImageDType(unittest.TestCase):
|
||||
# NOTE: the w1 grad must realize to a separate kernel
|
||||
assert w1.grad.uop.is_realized, f"never realized {w1.grad}"
|
||||
self.assertEqual(w1.grad.uop.base.buffer.dtype, dtypes.float32)
|
||||
self.assertEqual(len(sched), 9 if RANGEIFY else 10)
|
||||
self.assertEqual(len(sched), 9)
|
||||
|
||||
@unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
|
||||
class TestImageRealization(unittest.TestCase):
|
||||
|
||||
@@ -8,7 +8,7 @@ from tinygrad.uop.ops import UOp, Ops, GroupOp
|
||||
from tinygrad.device import Device, Buffer, is_dtype_supported
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner, get_program
|
||||
from tinygrad.helpers import Context, flatten, dedup, TC_SELECT, TC_OPT, RANGEIFY
|
||||
from tinygrad.helpers import Context, flatten, dedup, TC_SELECT, TC_OPT
|
||||
from tinygrad.dtype import DType, dtypes, PtrDType, AddrSpace
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
|
||||
@@ -314,7 +314,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
a.realize()
|
||||
np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
|
||||
|
||||
@unittest.skipIf(RANGEIFY and isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX indexes differently. might be ok?")
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX indexes differently. might be ok?")
|
||||
def test_where_fold(self):
|
||||
a = Tensor.ones(4, 4).contiguous().realize()
|
||||
b = a.shrink(((1, 2), None)).pad(((1, 2), None))
|
||||
|
||||
@@ -2,7 +2,7 @@ import time, math, unittest, functools, platform, warnings
|
||||
import numpy as np
|
||||
from typing import List, Callable
|
||||
import torch
|
||||
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, TRANSCENDENTAL, CPU_LLVM, AMD_LLVM, RANGEIFY
|
||||
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, TRANSCENDENTAL, CPU_LLVM, AMD_LLVM
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from tinygrad.device import is_dtype_supported
|
||||
@@ -3040,7 +3040,6 @@ class TestOps(unittest.TestCase):
|
||||
pos_weight=torch.tensor(pos_weight)),
|
||||
lambda x,y: x.binary_crossentropy_logits(y.clip(0,1),pos_weight=Tensor(pos_weight)))
|
||||
|
||||
@unittest.skipIf(RANGEIFY > 1, "broken on RANGEIFY > 1, TODO: fix")
|
||||
def test_cross_entropy_class_probabilities(self):
|
||||
helper_test_op([(32,), (32,)], lambda x,y: torch.nn.functional.cross_entropy(x, y), lambda x,y: x.cross_entropy(y))
|
||||
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y), lambda x,y: x.cross_entropy(y))
|
||||
|
||||
@@ -2,7 +2,7 @@ import unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, GlobalCounters, Context, Device
|
||||
from tinygrad.dtype import DTypeLike, dtypes
|
||||
from tinygrad.helpers import DEBUG, get_single_element, RANGEIFY
|
||||
from tinygrad.helpers import DEBUG, get_single_element
|
||||
from tinygrad.engine.realize import lower_schedule_item
|
||||
from tinygrad.device import is_dtype_supported
|
||||
|
||||
@@ -39,17 +39,17 @@ class TestFuse(unittest.TestCase):
|
||||
np_multi = fxn(*args, **kwargs).numpy()
|
||||
np.testing.assert_allclose(np_single, np_multi, atol=atol)
|
||||
|
||||
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
|
||||
@unittest.skip("needs RANGEIFY>1")
|
||||
def test_fuse_norm(self):
|
||||
a = Tensor.rand(50,50).realize()
|
||||
self._test_fuse(lambda a: a / a.mean(axis=1), a)
|
||||
|
||||
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
|
||||
@unittest.skip("needs RANGEIFY>1")
|
||||
def test_fuse_argmax(self):
|
||||
a = Tensor.rand(50,50).realize()
|
||||
self._test_fuse(lambda a: a.argmax(axis=-1), a)
|
||||
|
||||
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
|
||||
@unittest.skip("needs RANGEIFY>1")
|
||||
def test_fuse_softmax(self):
|
||||
a = Tensor.rand(50,50).realize()
|
||||
self._test_fuse(lambda a: a.softmax(axis=-1), a)
|
||||
@@ -60,7 +60,7 @@ class TestFuse(unittest.TestCase):
|
||||
self._test_fuse(lambda a,b: ((a@b).relu()+a).contiguous().softmax(axis=-1), a,b, allow_multiple=True)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
|
||||
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
|
||||
@unittest.skip("needs RANGEIFY>1")
|
||||
def test_fuse_softmax_dtype(self):
|
||||
a = Tensor.rand(50,50).realize()
|
||||
self._test_fuse(lambda a: a.softmax(axis=-1, dtype='half'), a, atol=3e-4)
|
||||
@@ -68,7 +68,7 @@ class TestFuse(unittest.TestCase):
|
||||
def test_fuse_arange_eye(self):
|
||||
self._test_fuse(lambda: Tensor.arange(10).reshape(10,1).expand(10,10) == Tensor.arange(10).reshape(1,10).expand(10,10))
|
||||
|
||||
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
|
||||
@unittest.skip("needs RANGEIFY>1")
|
||||
def test_double_gemm(self):
|
||||
N = 32
|
||||
with Context(TRACK_MATCH_STATS=0, DEBUG=0):
|
||||
@@ -91,7 +91,7 @@ class TestFuse(unittest.TestCase):
|
||||
return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype)
|
||||
self._test_fuse(embedding, a, atol=1e-5)
|
||||
|
||||
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
|
||||
@unittest.skip("needs RANGEIFY>1")
|
||||
def test_attention_kernel_count(self):
|
||||
wq = Tensor.empty(32, 32)
|
||||
wk = Tensor.empty(32, 32)
|
||||
@@ -104,7 +104,7 @@ class TestFuse(unittest.TestCase):
|
||||
s = attn.schedule()
|
||||
self.assertEqual(len(s), 4) # 3 matmul and 1 attention
|
||||
|
||||
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
|
||||
@unittest.skip("needs RANGEIFY>1")
|
||||
def test_flash_attention(self):
|
||||
BS = 4
|
||||
HEADS = 2
|
||||
@@ -172,7 +172,7 @@ class TestSoftmaxFusion(unittest.TestCase):
|
||||
|
||||
np.testing.assert_allclose(sout.numpy(), out.numpy(), atol=3e-7)
|
||||
|
||||
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
|
||||
@unittest.skip("needs RANGEIFY>1")
|
||||
def test_auto_softmax(self):
|
||||
print("*** softmax ***")
|
||||
with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, dtypes, TinyJit, UOp
|
||||
from tinygrad.helpers import RANGEIFY
|
||||
from tinygrad.apps.llm import apply_rope
|
||||
#from tinygrad.engine.realize import run_schedule
|
||||
|
||||
# TODO: test_scheduler, but just in uint
|
||||
class TestAttention(unittest.TestCase):
|
||||
@unittest.skipIf(RANGEIFY > 0, "not half on rangeify")
|
||||
def test_half_qkv_buffers(self):
|
||||
BS, seqlen, dim = 10, 4, 100
|
||||
q = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
|
||||
@@ -14,12 +12,11 @@ class TestAttention(unittest.TestCase):
|
||||
v = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
|
||||
attn = q.scaled_dot_product_attention(k, v)
|
||||
sched = attn.schedule()
|
||||
#run_schedule(sched[:])
|
||||
# attention has 5 kernels now
|
||||
self.assertEqual(len(sched), 4 if RANGEIFY else 5)
|
||||
softmax_inputs = sched[1:4]
|
||||
for i,si in enumerate(softmax_inputs):
|
||||
assert all(b.dtype == dtypes.half for b in si.bufs), f"non half {si.bufs=} in kernel {i}"
|
||||
# attention has 4 kernels now
|
||||
self.assertEqual(len(sched), 4)
|
||||
# softmax_inputs = sched[1:4]
|
||||
# for i,si in enumerate(softmax_inputs):
|
||||
# assert all(b.dtype == dtypes.half for b in si.bufs), f"non half {si.bufs=} in kernel {i}"
|
||||
|
||||
def test_apply_rope(self):
|
||||
x = Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import unittest, sys
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, GlobalCounters, dtypes, Context, nn
|
||||
from tinygrad.helpers import CI, Profiling, WINO, RANGEIFY
|
||||
from tinygrad.helpers import CI, Profiling, WINO
|
||||
|
||||
@unittest.skipIf(sys.platform.startswith("win"), "flaky on Windows")
|
||||
class TestWinogradClose(unittest.TestCase):
|
||||
@@ -61,9 +61,9 @@ class TestWinograd(unittest.TestCase):
|
||||
print(f"ops: normal {ops_normal:9d} wino {ops_wino:9d} ratio {ops_ratio:.2f}")
|
||||
print(f"mem: normal {mem_normal:9d} wino {mem_wino:9d} ratio {mem_ratio:.2f}")
|
||||
|
||||
if not RANGEIFY:
|
||||
self.assertLess(ops_ratio, 2.6) # TODO: there's issues with factorization now
|
||||
self.assertLess(mem_ratio, 10)
|
||||
# TODO: what's optimal on this?
|
||||
self.assertLess(ops_ratio, 4.3)
|
||||
self.assertLess(mem_ratio, 3)
|
||||
|
||||
def test_dtype(self):
|
||||
IC, OC, X, Y = 4,4,9,9
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Callable
|
||||
import functools
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL, RANGEIFY
|
||||
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
|
||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype
|
||||
from tinygrad.uop.spec import type_verify
|
||||
from tinygrad.renderer import Renderer
|
||||
@@ -38,11 +38,10 @@ rewrites_for_linearizer = [
|
||||
|
||||
def get_rewrites_for_renderer(opts:Renderer, optimize:bool=True, linearizer:bool=True) -> list[RewriteStep]:
|
||||
# cache with the values of the context vars
|
||||
return _get_rewrites_for_renderer(opts, optimize, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value, RANGEIFY.value)
|
||||
return _get_rewrites_for_renderer(opts, optimize, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value)
|
||||
|
||||
@functools.cache
|
||||
def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL,
|
||||
_RANGEIFY) -> list[RewriteStep]:
|
||||
def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
|
||||
# ** lowerer (rewrite_shapetracker_with_index) **
|
||||
ret: list[RewriteStep] = []
|
||||
|
||||
@@ -52,8 +51,7 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
|
||||
if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
|
||||
|
||||
# split ranges
|
||||
if _RANGEIFY:
|
||||
ret.append(RewriteStep(pm_split_ranges+pm_flatten_range, ctx=lambda _: {}, name="split ranges"))
|
||||
ret.append(RewriteStep(pm_split_ranges+pm_flatten_range, ctx=lambda _: {}, name="split ranges"))
|
||||
|
||||
# symbolic (NOTE: this is a requirement for pm_simplify_ranges to be correct)
|
||||
ret.append(RewriteStep(sym+pm_flatten_range, name="initial symbolic"))
|
||||
|
||||
Reference in New Issue
Block a user