delete DONT_REALIZE_EXPAND and DONT_GROUP_REDUCES (#12744)

does nothing now
chenyu
2025-10-16 14:11:33 -04:00
committed by GitHub
parent 98239f1156
commit 285534ce64
7 changed files with 32 additions and 50 deletions
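Both flags were plain ContextVars toggled in the tests through the Context manager; since nothing reads them anymore, the with Context(DONT_REALIZE_EXPAND=1) / with Context(DONT_GROUP_REDUCES=1) wrappers removed below were already no-ops. As a minimal sketch of that ContextVar/Context pattern (SOME_FLAG here is a hypothetical variable, not one of the deleted ones):

from tinygrad.helpers import ContextVar, Context

# hypothetical flag: default 0, overridable via a SOME_FLAG env var (assumed unset here)
SOME_FLAG = ContextVar("SOME_FLAG", 0)

def do_work():
  # behavior only changes if something actually checks the flag's value
  return "special" if SOME_FLAG.value else "normal"

print(do_work())            # "normal"
with Context(SOME_FLAG=1):  # Context temporarily overrides ContextVars by name
  print(do_work())          # "special"
print(do_work())            # "normal" again once the block exits

If no code path ever consults SOME_FLAG.value, wrapping calls in with Context(SOME_FLAG=1) changes nothing, which is the situation this commit cleans up for the two deleted flags.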


@@ -642,7 +642,7 @@ jobs:
ln -s /data/home/tiny/tinygrad/testsig-*.so .
PYTHONPATH=. CC=clang-19 CPU=1 CPU_LLVM=0 QUANT=1 CNT=0 python3 examples/test_onnx_imagenet.py https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx /tmp/model.quant.onnx
# benchmark on DSP with NOOPT=1, the devectorizer has issues
-PYTHONPATH=. CC=clang-19 DSP=1 DONT_REALIZE_EXPAND=1 NOOPT=1 CNT=2 DEBUG=2 python3 examples/test_onnx_imagenet.py /tmp/model.quant.onnx
+PYTHONPATH=. CC=clang-19 DSP=1 NOOPT=1 CNT=2 DEBUG=2 python3 examples/test_onnx_imagenet.py /tmp/model.quant.onnx
- name: Run process replay tests
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
- uses: actions/upload-artifact@v4


@@ -19,8 +19,8 @@ from tinygrad.helpers import fetch, getenv
# QUANT=1 python3 examples/test_onnx_imagenet.py
# https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx
-# DONT_REALIZE_EXPAND=1 python3 examples/test_onnx_imagenet.py /tmp/model.quant.onnx
-# VIZ=1 DONT_REALIZE_EXPAND=1 python3 examples/benchmark_onnx.py /tmp/model.quant.onnx
+# python3 examples/test_onnx_imagenet.py /tmp/model.quant.onnx
+# VIZ=1 python3 examples/benchmark_onnx.py /tmp/model.quant.onnx
def imagenet_dataloader(cnt=0):
input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)


@@ -1,4 +1,4 @@
-from tinygrad import Tensor, dtypes, Context, GlobalCounters
+from tinygrad import Tensor, dtypes, GlobalCounters
dtypes.default_float = dtypes.float16
from tinygrad.dtype import to_dtype
from tinygrad.helpers import getenv
@@ -13,6 +13,5 @@ if __name__ == "__main__":
# test single kernel softmax
GlobalCounters.reset()
-with Context(DONT_GROUP_REDUCES=1):
-  single_kernel_softmax(t, -1, acc_dtype).realize()
+single_kernel_softmax(t, -1, acc_dtype).realize()


@@ -72,7 +72,7 @@ class TestQuantizeOnnxCPU(unittest.TestCase):
out_file = get_quantized_model(sz)
run_onnx = OnnxRunner(out_file)
inp = Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32))
-with Context(DONT_REALIZE_EXPAND=1, QUANTIZE=1):
+with Context(QUANTIZE=1):
sched = run_onnx({"input":inp})["output"].schedule()
ei = lower_schedule_item(sched[-2])
daccs = [u for u in ei.prg.p.uops if u.op is Ops.DEFINE_REG]
@@ -86,8 +86,7 @@ class TestQuantizeOnnx(unittest.TestCase):
# divide is ~1500-2000 without reduce_range, 750-900 with it
out_file = get_quantized_model(sz)
run_onnx_jit, _ = load_onnx_model(out_file)
-with Context(DONT_REALIZE_EXPAND=1):
-  run_onnx_jit(input=Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32)))
+run_onnx_jit(input=Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32)))
def test_prequant_conv2d_1x1(self):
X = Tensor(np.random.uniform(0, 255, size=(1, 32, 128, 128)).astype(np.uint8))
@@ -109,11 +108,10 @@ class TestQuantizeOnnx(unittest.TestCase):
N = 512
X = Tensor(np.random.uniform(0, 255, size=(N,N)).astype(xi))
W = Tensor(np.random.uniform(0, 255, size=(N,N)).astype(wi))
-with Context(DONT_REALIZE_EXPAND=1):
-  # this divide is interesting and forces the accumulator to actually be an int
-  out = (X.cast("int").matmul(W.cast("int"))//1000).cast("int8")
-  opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
-  sexec(out, opts)
+# this divide is interesting and forces the accumulator to actually be an int
+out = (X.cast("int").matmul(W.cast("int"))//1000).cast("int8")
+opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
+sexec(out, opts)
def test_prequant_gemm_handcode(self):
src = """typedef int int128 __attribute__((aligned(512),vector_size(512)));
@@ -203,14 +201,12 @@ class TestQuantizeOnnx(unittest.TestCase):
def test_prequant_gemm_intacc(self, xi=np.uint8, wi=np.uint8, replace_src=None, N=512, clip=True, opts=None):
X = Tensor(m1:=(np.random.uniform(0, 255, size=(N,N)).astype(xi))).realize()
W = Tensor(m2:=(np.random.uniform(0, 255, size=(N,N)).astype(wi))).realize()
-# ugh, it's so broken with those casts. need DONT_REALIZE_EXPAND=1 python3 test/test_quantize_onnx.py TestQuantizeOnnx.test_prequant
tg_dtype = dtypes.int8 if xi == np.int8 else dtypes.uint8
-with Context(DONT_REALIZE_EXPAND=1):
-  out = (X.int().matmul(W.int())//1000)
-  if clip: out = out.clip(dtypes.min(tg_dtype),dtypes.max(tg_dtype))
-  out = out.cast(tg_dtype)
-  opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)] if opts is None else opts
-  sexec(out, opts, replace_src, run_count=1)
+out = (X.int().matmul(W.int())//1000)
+if clip: out = out.clip(dtypes.min(tg_dtype),dtypes.max(tg_dtype))
+out = out.cast(tg_dtype)
+opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)] if opts is None else opts
+sexec(out, opts, replace_src, run_count=1)
tout = out.numpy()
mout = ((m1.astype(np.int32) @ m2.astype(np.int32)) // 1000)
if clip: mout = mout.clip(dtypes.min(tg_dtype),dtypes.max(tg_dtype))
@@ -225,7 +221,6 @@ class TestQuantizeOnnx(unittest.TestCase):
def test_prequant_gemv(self):
N = 2048
-# ugh, it's so broken with those casts. need DONT_REALIZE_EXPAND=1 python3 test/test_quantize_onnx.py TestQuantizeOnnx.test_prequant
X = Tensor(np.random.uniform(0, 255, size=(1,N)).astype(np.uint8)).realize()
W = Tensor(np.random.uniform(0, 255, size=(N,N)).astype(np.uint8)).realize()
#out = X.cast(dtypes.int) @ W.cast(dtypes.int)


@@ -171,8 +171,7 @@ class TestSchedule(unittest.TestCase):
def test_rand_recompute_arange(self):
x = Tensor.rand(32)
-with Context(DONT_GROUP_REDUCES=1):
-  check_schedule(x, 3, [Tensor._device_rng_counters[x.device]])
+check_schedule(x, 3, [Tensor._device_rng_counters[x.device]])
def test_empty_is_not_realized(self):
a = Tensor.empty(10)
@@ -276,7 +275,7 @@ class TestSchedule(unittest.TestCase):
a = Tensor.randn(10,10,10).realize()
b = Tensor.randn(10,10,1).realize()
c = a.sum(axis=0, keepdim=True).permute(2,1,0) + b
-with Context(DONT_GROUP_REDUCES=1): run_schedule(check_schedule(c, 1))
+run_schedule(check_schedule(c, 1))
np.testing.assert_allclose(c.numpy(), np.sum(a.numpy(), axis=0, keepdims=True).transpose(2,1,0)+b.numpy())
def test_binop_early_reshape_reduce_fusion(self):
@@ -1976,8 +1975,7 @@ class TestSwizzle(unittest.TestCase):
a = Tensor.randint(32, 32).realize()
r = (a+a).sum(1).sum(0)
# double reduce collapses to a single reduce
-with Context(DONT_GROUP_REDUCES=1):
-  run_schedule(check_schedule(r, 1))
+run_schedule(check_schedule(r, 1))
self.assertEqual(r.numpy(), (a.numpy()+a.numpy()).sum(1).sum(0))
def test_single_swizzle(self):
@@ -1997,33 +1995,29 @@ class TestSwizzle(unittest.TestCase):
b = Tensor.randint(4,).realize()
# parallel reduce!
add = a.sum(0)+b.sum(0)
-with Context(DONT_GROUP_REDUCES=1):
-  run_schedule(check_schedule(add, 1))
+run_schedule(check_schedule(add, 1))
self.assertEqual(add.numpy(), a.numpy().sum(0)+b.numpy().sum(0))
@unittest.skip("TODO: how do we express the norm")
def test_softmax_one_kernel(self):
Tensor.manual_seed(0)
with Context(DEBUG=0, TRACK_MATCH_STATS=0):
a = Tensor.randn(32, 32).realize()
t = a.softmax()
-with Context(DONT_GROUP_REDUCES=1, DONT_REALIZE_EXPAND=1):
-  check_schedule(t, 1)
+check_schedule(t, 1)
def test_argmax_one_kernel(self):
Tensor.manual_seed(0)
with Context(DEBUG=0, TRACK_MATCH_STATS=0):
a = Tensor.randn(10, 20).realize()
t = a.argmax(0)
-with Context(DONT_GROUP_REDUCES=1, DONT_REALIZE_EXPAND=1): t.realize()
+check_schedule(t, 1)
def test_swizzle_reduceop(self):
Tensor.manual_seed(0)
x = Tensor.randn(4,4).realize()
y = Tensor.randn(4,4,4).realize()
out = x.reshape(4,4,1).expand(4,4,4).sum(axis=(1,))+y
-with Context(DONT_REALIZE_EXPAND=1, DONT_GROUP_REDUCES=1):
-  run_schedule(check_schedule(out, 1))
+run_schedule(check_schedule(out, 1))
np.testing.assert_allclose(out.numpy(), np.tile(x.numpy().reshape(4,4,1), (1,1,4)).sum(axis=1)+y.numpy())
def test_permute_rewrite(self):
@@ -2031,7 +2025,7 @@ class TestSwizzle(unittest.TestCase):
y = Tensor.randn(4, 1, 16).realize()
z = Tensor.randn(4, 4, 1).realize()
t = (x*y).sum(axis=(0, 2)).reshape(1, 4, 1).permute(0, 2, 1)+z
-with Context(DONT_GROUP_REDUCES=1, DONT_REALIZE_EXPAND=1): run_schedule(check_schedule(t, 1))
+run_schedule(check_schedule(t, 1))
t_np = (x.numpy()*y.numpy()).sum(axis=(0, 2)).reshape(1, 4, 1).transpose(0, 2, 1)+z.numpy()
np.testing.assert_allclose(t.numpy(), t_np, atol=1e-6, rtol=1e-3)
@@ -2042,14 +2036,14 @@ class TestSwizzle(unittest.TestCase):
a_reduce = a.sum(axis=(2,), keepdim=True).sum(axis=(1,))
b_reduce = b.sum(axis=(0,))
t = a_reduce+b_reduce
-with Context(DONT_GROUP_REDUCES=1, DONT_REALIZE_EXPAND=1): run_schedule(check_schedule(t, 1))
+run_schedule(check_schedule(t, 1))
def test_parallel_reduce_possible(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 2, 2).realize()
y = Tensor.randn(4, 2, 2).realize()
t = x.sum(axis=1)+y.sum(axis=1)
-with Context(DONT_GROUP_REDUCES=1): run_schedule(check_schedule(t, 1))
+run_schedule(check_schedule(t, 1))
np.testing.assert_allclose(t.numpy(), x.numpy().sum(axis=1)+y.numpy().sum(axis=1), atol=1e-6, rtol=1e-3)
# kernels can only have 1 or n in each dim
@@ -2058,7 +2052,7 @@ class TestSwizzle(unittest.TestCase):
x = Tensor.randn(4, 2, 2).realize()
y = Tensor.randn(4, 3, 2).realize()
t = x.sum(axis=1)+y.sum(axis=1)
-with Context(DONT_GROUP_REDUCES=1): run_schedule(check_schedule(t, 1))
+run_schedule(check_schedule(t, 1))
np.testing.assert_allclose(t.numpy(), x.numpy().sum(axis=1)+y.numpy().sum(axis=1), atol=1e-6, rtol=1e-3)
def test_unsafe_pad(self):


@@ -165,8 +165,7 @@ class TestSoftmaxFusion(unittest.TestCase):
sout.realize()
print("*** single kernel softmax ***")
-# NOTE: DONT_GROUP_REDUCES is required here
-with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2), DONT_GROUP_REDUCES=1):
+with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
out = single_kernel_softmax(self.test)
out.realize()
@@ -186,7 +185,6 @@ class TestSoftmaxFusion(unittest.TestCase):
np.testing.assert_allclose(sout.numpy(), out.numpy(), atol=3e-7)
-@unittest.skip("recursion error no longer raised")
def test_softmax_bw(self):
print("*** softmax bw ***")
self.test.requires_grad_()
@@ -197,14 +195,11 @@ class TestSoftmaxFusion(unittest.TestCase):
self.test.grad = None
print("*** single kernel softmax bw ***")
-# NOTE: DONT_GROUP_REDUCES is required here
-# TODO: fix RecursionError with DONT_GROUP_REDUCES
-with self.assertRaises(RecursionError):
-  with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2), DONT_GROUP_REDUCES=1):
-    single_kernel_softmax(self.test).sum().backward()
-  g = self.test.grad.realize()
-  np.testing.assert_allclose(sg.numpy(), g.numpy(), atol=1e-7)
+with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
+  single_kernel_softmax(self.test).sum().backward()
+g = self.test.grad.realize()
+np.testing.assert_allclose(sg.numpy(), g.numpy(), atol=1e-7)
if __name__ == '__main__':
unittest.main()


@@ -158,7 +158,6 @@ SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), Conte
PICKLE_BUFFERS, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("LRU", 1)
CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
DISABLE_COMPILER_CACHE, BLOCK_REORDER = ContextVar("DISABLE_COMPILER_CACHE", 0), ContextVar("BLOCK_REORDER", 1)
-DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0), ContextVar("DONT_GROUP_REDUCES", 0)
QUANTIZE, VALIDATE_WITH_CPU, DISABLE_FAST_IDIV = ContextVar("QUANTIZE", 0), ContextVar("VALIDATE_WITH_CPU", 0), ContextVar("DISABLE_FAST_IDIV", 0)
CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0)
ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("MAX_BUFFER_SIZE", 0)