remove trivial use of RANGEIFY flag (#12550)

some tests need update still
Author: chenyu
Date: 2025-10-09 14:29:38 +08:00
Committed by: GitHub
Parent: 80d99d52a5
Commit: ae51bdd06a
17 changed files with 86 additions and 132 deletions
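
For context, a minimal sketch of the pattern this commit deletes, assuming the pre-commit tree where RANGEIFY was a ContextVar exported from tinygrad.helpers (the tensors and counts below are illustrative, not taken from the repo): scripts scoped the flag with Context, and tests branched on its value.

# pre-commit pattern (assumes tinygrad.helpers still exports RANGEIFY)
from tinygrad import Tensor
from tinygrad.helpers import Context, RANGEIFY

a, b = Tensor.empty(16, 16), Tensor.empty(16, 16)
with Context(RANGEIFY=1):                # scoped override of the ContextVar
  sched = (a @ b).schedule()             # build the schedule with rangeify on
expected_kernels = 1 if RANGEIFY else 2  # tests branched on the flag like this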


@@ -80,7 +80,6 @@ print("******** third, the UOp ***********")
 from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.schedule import create_schedule_with_vars
-from tinygrad.helpers import RANGEIFY
 from tinygrad.schedule.rangeify import get_rangeify_map
 # allocate some values + load in values


@@ -49,7 +49,6 @@ def rangeify_kernel3():
 b = Tensor.empty(N,N)
 c = a@b
 #c = c.reshape((32,2,16,4,32,2,16,4)).contiguous()
-with Context(RANGEIFY=1):
 sink = c.schedule()[-1].ast
 #print(sink)
@@ -329,7 +328,7 @@ if __name__ == "__main__":
 elif HL == 1: hprg = hl_spec_kernel3()
 else: hprg = hand_spec_kernel3()
 if HL == 3:
-with Context(RANGEIFY=1, BLOCK_REORDER=0):
+with Context(BLOCK_REORDER=0):
 prg = get_program(hprg, Device.default.renderer)
 else:
 prg = get_program(hprg, Device.default.renderer)


@@ -4,7 +4,7 @@ import numpy as np
 import torch
 from tinygrad import GlobalCounters, Tensor, Device
-from tinygrad.helpers import getenv, RANGEIFY
+from tinygrad.helpers import getenv
 from tinygrad.nn.state import get_parameters
 from tinygrad.engine.realize import capturing
 from tinygrad.tensor import _to_np_dtype
@@ -164,7 +164,7 @@ class TestOpt(unittest.TestCase):
 def test_permute_was_pushed(self):
 a = Tensor.randn(16, 16, 16)
-with CLCache(1 if RANGEIFY else 2):
+with CLCache(1):
 c = a.sum(2)
 d = c.permute(1,0).contiguous()
 d.realize()
@@ -172,7 +172,7 @@ class TestOpt(unittest.TestCase):
 def test_permute_was_pushed_through_contract_reshape(self):
 a = Tensor.randn(4, 4, 4, 4, 4)
-with CLCache(1 if RANGEIFY else 2):
+with CLCache(1):
 c = a.sum(-1)
 d = c.reshape(16,16).permute(1,0).contiguous()
 d.realize()
@@ -180,7 +180,7 @@ class TestOpt(unittest.TestCase):
 def test_permute_was_pushed_through_contractw1s_reshape(self):
 a = Tensor.randn(4, 4, 4, 4, 4)
-with CLCache(1 if RANGEIFY else 2):
+with CLCache(1):
 c = a.sum(-1)
 d = c.reshape(16,1,16).permute(2,1,0).contiguous()
 d.realize()
@@ -188,7 +188,7 @@ class TestOpt(unittest.TestCase):
 def test_permute_was_pushed_through_expand_reshape(self):
 a = Tensor.randn(16, 16, 16)
-with CLCache(1 if RANGEIFY else 2):
+with CLCache(1):
 c = a.sum(2)
 d = c.reshape(4,4,4,4).permute(2,3,0,1).contiguous()
 d.realize()
@@ -220,7 +220,7 @@ class TestOpt(unittest.TestCase):
 for axis in [0, 1]:
 for n in [4, 8, 16]:
 b = torch.ones(n, n).sum(axis).reshape(n, 1).expand(n, n).sum(axis)
-with CLCache(allowed=3 if RANGEIFY else 2):
+with CLCache(allowed=3):
 a = Tensor.ones(n, n).contiguous().sum(axis).reshape(n, 1).expand(n, n).sum(axis)
 a.realize()
 np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)
@@ -229,7 +229,7 @@ class TestOpt(unittest.TestCase):
 axis1, axis2 = 0, 1
 for n in [4, 8, 16]:
 b = torch.ones(n, n).sum(axis1).reshape(n, 1).expand(n, n).sum(axis2)
-with CLCache(allowed=3 if RANGEIFY else 2):
+with CLCache(allowed=3):
 a = Tensor.ones(n, n).contiguous().sum(axis1).reshape(n, 1).expand(n, n).sum(axis2)
 a.realize()
 np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)


@@ -1,4 +1,4 @@
-import time, struct, unittest
+import time, struct
 from typing import Any, Callable
 import numpy as np
 from tinygrad import Tensor, dtypes, Device
@@ -7,7 +7,7 @@ from tinygrad.tensor import _to_np_dtype
 from tinygrad.engine.realize import Runner
 from tinygrad.dtype import DType
 from tinygrad.nn.state import get_parameters
-from tinygrad.helpers import T, CI, RANGEIFY
+from tinygrad.helpers import T, CI
 from tinygrad.codegen import full_rewrite
 from tinygrad.runtime.ops_python import PythonProgram, PythonRenderer, PythonCompiler
@@ -62,6 +62,3 @@ def not_support_multi_device():
 # NOTE: This will open REMOTE if it's the default device
 REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties.real_device)
-def expect_rangeify_fails(fxn): return (unittest.expectedFailure if RANGEIFY else (lambda f:f))(fxn)
-def expect_nonrangeify_fails(fxn): return (unittest.expectedFailure if not RANGEIFY else (lambda f:f))(fxn)


@@ -1,6 +1,6 @@
 import unittest
 from tinygrad import Device, Tensor, dtypes
-from tinygrad.helpers import CI, RANGEIFY
+from tinygrad.helpers import CI
 from tinygrad.codegen.opt import Opt, OptOps, KernelOptError
 # TODO: write a clean version of this
@@ -351,7 +351,6 @@ class TestKernelOpts(unittest.TestCase):
 ] + [[Opt(OptOps.THREAD, 0, 4)] if Device[Device.DEFAULT].renderer.global_max[0] >= 4 else []]
 + [[Opt(OptOps.THREAD, 0, 8)] if Device[Device.DEFAULT].renderer.global_max[0] >= 8 else []])
-@unittest.skipUnless(RANGEIFY>=1, "Kernel only fuses with rangeify")
 def test_double_sum_group(self):
 a = Tensor.rand(4, 4, 4)
 r = a.sum((1, 2)).sum()


@@ -1,7 +1,7 @@
 import unittest
 import numpy as np
 from tinygrad import Tensor, GlobalCounters, dtypes, nn, Device, Variable
-from tinygrad.helpers import CI, Context, getenv, RANGEIFY
+from tinygrad.helpers import CI, Context, getenv
 from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
 from tinygrad.uop.ops import Ops
@@ -95,7 +95,7 @@ class TestIndexing(unittest.TestCase):
 X = dataset[idxs]
 assert X.shape == (4,DDIM)
 sched = X.schedule()
-self.assertEqual(len(sched), 1 if RANGEIFY else 2)
+self.assertEqual(len(sched), 1)
 run_schedule(sched)
 assert GlobalCounters.global_ops < 4*DSET, f"too many ops {GlobalCounters.global_ops} != {4*DSET}"
 np.testing.assert_allclose(real_index, X.numpy())


@@ -1,6 +1,5 @@
 #!/usr/bin/env python
 import unittest
-import contextlib
 import numpy as np
 from tinygrad import dtypes, Tensor, TinyJit, GlobalCounters, Variable
 from tinygrad.device import is_dtype_supported
@@ -271,8 +270,6 @@ class TestAssign(unittest.TestCase):
 b.assign(a.contiguous()).realize()
 assert GlobalCounters.kernel_count - kc == 2
-# passing in RANGEIFY=1, RANGEIFY=0 asserts permuted assigns it can't fuse
-def assert_permuted_assign(self): return self.assertRaisesRegex(RuntimeError, "contiguous") if not RANGEIFY else contextlib.nullcontext()
 def test_permuted_assignment(self):
 a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
 b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
@@ -280,7 +277,6 @@ class TestAssign(unittest.TestCase):
 b.realize()
 ba1 = a.uop.base.realized
 bb1 = b.uop.base.realized
-with self.assert_permuted_assign():
 a = a.permute(1,0)
 a += b
 a.realize()
@@ -297,7 +293,6 @@ class TestAssign(unittest.TestCase):
 #GlobalCounters.cache = []
 ba1 = a.uop.base.realized # noqa: F841
 bb1 = b.uop.base.realized # noqa: F841
-with self.assert_permuted_assign():
 a.assign(a.permute(1,0) + b) # this should not work!
 a.realize()
 ba2 = a.uop.base.realized # noqa: F841
@@ -345,8 +340,6 @@ class TestAssign(unittest.TestCase):
 def test_permuted_assignment_correct(self):
 a = Tensor.arange(4 * 4).reshape(4, 4).contiguous().realize()
 b = Tensor.arange(4 * 4).reshape(4, 4).contiguous().realize()
-# TODO: swizzler.py limitation, should NOT raise AssertionError from numpy.
-with self.assert_permuted_assign():
 a = a.permute(1, 0)
 new_val = a + b
 a.assign(new_val)
@@ -355,7 +348,6 @@ class TestAssign(unittest.TestCase):
 def test_permuted_reduceop_child_dual_use(self):
 a = Tensor.randn(32, 32, 32).realize()
 b = Tensor.full((32, 32), 1.).contiguous().realize()
-with self.assert_permuted_assign():
 r = a.sum(axis=1)
 b.assign(r + b.permute(1, 0))
 b.realize()
@@ -401,7 +393,6 @@ class TestAssign(unittest.TestCase):
 def test_permuted_assignment_masked_view_not_contiguous(self):
 a = Tensor.ones(4, 4).contiguous().realize()
-with self.assert_permuted_assign():
 b = a.shrink((None, (0, 2))).pad((None, (0, 2)), value=2).permute(1, 0)
 a.assign(a + b)
 a.realize()


@@ -3164,8 +3164,8 @@ class TestOps(unittest.TestCase):
 helper_test_op([(32,10)], lambda x: x.masked_fill((x>0.1).detach(), -math.inf))
 helper_test_op([(32,10)], lambda x: x.masked_fill((x<0.1).detach(), -math.inf))
-@unittest.skipIf(RANGEIFY and (getenv("MOCKGPU") or Device.DEFAULT == "PYTHON"), "very slow on MOCKGPU because reduce does not fold")
-@unittest.skipIf(RANGEIFY and Device.DEFAULT == "WEBGPU", "webgpu runtime issue")
+@unittest.skipIf((getenv("MOCKGPU") or Device.DEFAULT == "PYTHON"), "very slow on MOCKGPU because reduce does not fold")
+@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu runtime issue")
 def test_masked_select(self):
 helper_test_op([(32, 10)], lambda x: x.masked_select(x>0.5), lambda x: x.masked_select(x>0.5), forward_only=True)
 helper_test_op([(32, 10)], lambda x: x.masked_select(torch.tensor(True)), lambda x: x.masked_select(Tensor(True)), forward_only=True)


@@ -1,9 +1,8 @@
 import unittest
 from tinygrad import Tensor, nn
-from tinygrad.helpers import RANGEIFY, Context, GlobalCounters
+from tinygrad.helpers import Context, GlobalCounters
 from tinygrad.uop.ops import UOp, graph_rewrite, PatternMatcher, UPat, Ops
-@unittest.skipIf(RANGEIFY<1, "tests only for RANGEIFY")
 class TestRangeifyAssign(unittest.TestCase):
 def test_assign_permuted(self):
 A = Tensor.empty(4, 4, dtype='int')
@@ -55,7 +54,6 @@ class TestRangeifyOpt(unittest.TestCase):
 A = Tensor.empty(8,8,8,8).permute(1,0,3,2).flatten()
 A.sum().realize()
-@unittest.skipIf(RANGEIFY<1, "tests only for RANGEIFY")
 class TestRangeify(unittest.TestCase):
 def test_groupnorm(self):
 # ranges 1 and 3 are merging
@@ -230,7 +228,6 @@ class TestRangeify(unittest.TestCase):
 # contiguous + reduce can support ranges?
 @unittest.skip("okay to disable this for now")
-@unittest.skipIf(RANGEIFY<1, "tests only for RANGEIFY")
 class TestOuterworld(unittest.TestCase):
 def test_passthrough_range(self):
 t = Tensor.rand(10, 10).realize()


@@ -17,7 +17,6 @@ from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context,
 from tinygrad.schedule.rangeify import get_rangeify_map, Kernel
 from tinygrad.engine.schedule import create_schedule_with_vars
 from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
-from test.helpers import expect_rangeify_fails, expect_nonrangeify_fails
 class KernelCountException(Exception): pass
 def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Tensor]|None=None, filter_sink=True):
@@ -117,7 +116,7 @@ class TestSchedule(unittest.TestCase):
 a = Tensor.empty(10)
 b = Tensor.empty((1,), device="CPU").expand(10).contiguous()
 c = a+b
-with self.assertRaisesRegex(RuntimeError, "all buffers must be on the same device"): check_schedule(c, 2 if RANGEIFY else 1)
+with self.assertRaisesRegex(RuntimeError, "all buffers must be on the same device"): check_schedule(c, 2)
 @unittest.skipUnless(is_dtype_supported(dtypes.half) and getenv("CAST_AFTER_EXPAND"), "need half and CAST_AFTER_EXPAND=1")
 @unittest.skip("CAST_AFTER_EXPAND is not supported")
@@ -343,7 +342,7 @@ class TestSchedule(unittest.TestCase):
 r1 = (x - r0).sum(axis=0).div(2)
 out0 = r0 + y
 out1 = r1 + y
-schedule = check_schedule([out0, out1], 2 if RANGEIFY else 4)
+schedule = check_schedule([out0, out1], 2)
 reduceops = [x for si in schedule for x in si.ast.toposort() if x.op in {Ops.REDUCE_AXIS, Ops.REDUCE}]
 assert len(reduceops) in [2,3] # why is RANGEIFY different?
@@ -712,7 +711,7 @@ class TestSchedule(unittest.TestCase):
 check_schedule(b, 0)
 self.assertEqual(b.item(), 1)
-@expect_rangeify_fails
+@unittest.expectedFailure
 def test_multioutput_ast(self):
 a = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().uop
 b = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().uop
@@ -919,7 +918,7 @@ class TestSchedule(unittest.TestCase):
 out0 = a.sum() + 2
 out1 = a.sum() + 4
 out2 = out0 * out1
-run_schedule(check_schedule([out0, out1, out2], 1 if RANGEIFY else 4))
+run_schedule(check_schedule([out0, out1, out2], 1))
 np.testing.assert_allclose(out0.numpy(), out0_np:=a.numpy().sum()+2, atol=1e-4, rtol=1e-6)
 np.testing.assert_allclose(out1.numpy(), out1_np:=a.numpy().sum()+4, atol=1e-4, rtol=1e-6)
 np.testing.assert_allclose(out2.numpy(), out0_np*out1_np, atol=1e-4, rtol=1e-6)
@@ -930,7 +929,7 @@ class TestSchedule(unittest.TestCase):
 out0 = a.sum().exp2()
 # out1 has two paths to a.sum()
 out1 = a.sum() + out0
-run_schedule(check_schedule([out0, out1], 1 if RANGEIFY else 3))
+run_schedule(check_schedule([out0, out1], 1))
 np.testing.assert_allclose(out0.numpy(), out0_np:=np.exp2(a.numpy().sum()), atol=1e-4, rtol=1e-4)
 np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+out0_np, atol=1e-4, rtol=1e-6)
@@ -1022,7 +1021,7 @@ class TestSchedule(unittest.TestCase):
 b = Tensor.empty(10,)
 c = a.sum() + b[0]
 d = a.sum() + 2
-check_schedule([c, d], 1 if RANGEIFY else 3)
+check_schedule([c, d], 1)
 def test_reduce_multiple_paths_midshrink(self):
 a = Tensor.empty(4, 4)
@@ -1186,14 +1185,14 @@ class TestSchedule(unittest.TestCase):
 np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4)
 @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
-@expect_rangeify_fails
+@unittest.expectedFailure
 def test_softmax_upcast(self):
 # input half, softmax in float
 Tensor.manual_seed(0)
 x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.half).realize()
 out = x.softmax(dtype=dtypes.float)
 sched = out.schedule()
-self.assertEqual(len(sched), 2 if RANGEIFY else 3)
+self.assertEqual(len(sched), 2)
 self.assertEqual(sched[0].bufs[0].dtype, dtypes.half)
 # input float, softmax in float
@@ -1323,7 +1322,7 @@ class TestSchedule(unittest.TestCase):
 check_schedule(opt.schedule_step(), 14)
 @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
-@expect_rangeify_fails
+@unittest.expectedFailure
 def test_prefer_half_buffer(self):
 x = Tensor.ones(4).contiguous().realize()
 # y = Tensor.ones(4).contiguous().realize()
@@ -1475,7 +1474,7 @@ class TestSchedule(unittest.TestCase):
 e = c * d
 f = b.sum() - e
 # run_schedule(check_schedule([c, d, e, f], 1))
-run_schedule(check_schedule([c, d, e, f], 2 if RANGEIFY else 5))
+run_schedule(check_schedule([c, d, e, f], 2))
 np.testing.assert_allclose(c.numpy(), c_np:=a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
 np.testing.assert_allclose(d.numpy(), d_np:=a.numpy().sum()*2, atol=1e-4, rtol=1e-4)
 np.testing.assert_allclose(e.numpy(), e_np:=c_np*d_np, atol=1e-4, rtol=1e-4)
@@ -1690,7 +1689,7 @@ class TestSchedule(unittest.TestCase):
 def test_late_fusion_post_expand(self):
 self._test_fusion([(32, 32)], lambda a:a-a.sum(1), 2)
-@expect_rangeify_fails
+@unittest.expectedFailure
 def test_cast_padded_view(self):
 a = Tensor.arange(4).reshape(1, 4)
 casted_view = a.pad(((0, 1), (0, 0))).cast(dtypes.float)
@@ -1720,7 +1719,7 @@ class TestSchedule(unittest.TestCase):
 self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])
 @given(strat.sampled_from(dtypes.all), strat.sampled_from(dtypes.all))
-@expect_rangeify_fails
+@unittest.expectedFailure
 def test_cast_padded_const(self, dt1, dt2):
 assume(is_dtype_supported(dt1) and is_dtype_supported(dt2))
 a = Tensor(1, dtype=dt1).reshape(1, 1).pad(((1, 1), None))
@@ -1891,8 +1890,6 @@ class TestSchedule(unittest.TestCase):
 tst = x.shrink((None, (0, 2))).assign(a).realize()
 xref[:, :2] = np.arange(8).reshape(4, 2)+y.numpy()
 np.testing.assert_equal(x.numpy(), xref)
-if RANGEIFY > 0:
-# NOTE: this is a bug on non rangeify
 np.testing.assert_equal(tst.numpy(), a.numpy())
 def test_setitem_sched(self, mop=lambda x:x, expected_kcount=1):
@@ -1904,7 +1901,6 @@ class TestSchedule(unittest.TestCase):
 run_schedule(sched)
 self.assertListEqual(a.tolist(), expected)
 self.assertEqual(kcount, expected_kcount)
-@unittest.skipUnless(RANGEIFY>0, "this asserts on non rangeify")
 def test_setitem_permuted_sched(self): self.test_setitem_sched(lambda x: x.T, 2)
 def test_setitem_paddded_sched(self): self.test_setitem_sched(lambda x: x.shrink_to(4, 1).pad_to(4, 4), 1)
@@ -1943,7 +1939,7 @@ class TestSchedule(unittest.TestCase):
 r = (X+Tensor.arange(16).reshape(4, 4)).sum()
 out0 = r+2
 out1 = r+3
-run_schedule(check_schedule([out0, out1], 1 if RANGEIFY else 3))
+run_schedule(check_schedule([out0, out1], 1))
 r_ref = (X.numpy()+np.arange(16).reshape(4, 4)).sum()
 np.testing.assert_allclose(out0.numpy(), r_ref+2, rtol=2e-7)
 np.testing.assert_allclose(out1.numpy(), r_ref+3, rtol=2e-7)
@@ -2088,7 +2084,7 @@ class TestView(unittest.TestCase):
 run_schedule(sched)
 np.testing.assert_equal(b.numpy(), 0)
-@expect_rangeify_fails
+@unittest.expectedFailure
 def test_mask_dim_1(self):
 # mask out dim = 1 works too
 a = Tensor.rand(10, 10).realize()
@@ -2236,7 +2232,6 @@ class TestCopyFolding(unittest.TestCase):
 b.realize()
 self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
-@expect_nonrangeify_fails
 def test_permute_on_disk_contiguous(self):
 with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer())
 a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}")
@@ -2251,8 +2246,6 @@ class TestCopyFolding(unittest.TestCase):
 self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
 # NOTE: disk permute must come after COPY
-# TODO: this is wrong because of the permute
-@expect_nonrangeify_fails
 def test_permute_after_shrink_on_disk(self):
 with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).realize().uop.base.buffer.as_buffer())
 a = Tensor.empty(5, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_5_permute')}")
@@ -2396,12 +2389,8 @@ class TestUOpBecome(unittest.TestCase):
 a = Tensor.empty(4, 1)
 b = a.expand(4, 4).reciprocal()
 check_schedule(b, 1)
-if RANGEIFY:
 self.assertEqual(b.uop.base.buffer.size, 4)
 self.assertEqual(b.uop.shape, (4, 4))
-return
-self.assertEqual(b.uop.base.buffer.size, 16)
-self.assertEqual(b.uop.st, ShapeTracker.from_shape((4, 4)))
 def test_reorder_expand_alt(self):
 x = Tensor.empty(4, 1)
@@ -2410,7 +2399,7 @@ class TestUOpBecome(unittest.TestCase):
 z = (img*x) / y
 check_schedule(z, 1)
-@expect_rangeify_fails
+@unittest.expectedFailure
 def test_become_existing_buffer(self):
 a = Tensor.empty(4, 4)
 b = a*1
@@ -2444,7 +2433,7 @@ class TestUOpBecome(unittest.TestCase):
 assert UPat(Ops.CONST, arg=3).match(const_add.uop.base, {})
 # tensors can become another realized tensor source
-@expect_rangeify_fails
+@unittest.expectedFailure
 def test_become_existing_buf_simple(self):
 a = Tensor.empty(4, 4)
 b = a+0
@@ -2453,14 +2442,14 @@ class TestUOpBecome(unittest.TestCase):
 self.assertIs(a.uop, b.uop)
 # they can also chain other movement ops on top of the tensor source
-@expect_rangeify_fails
+@unittest.expectedFailure
 def test_become_existing_buf_view(self):
 a = Tensor.empty(4, 4)
 b = a.permute((1, 0))+0
 check_schedule(b, 0)
 self.assertEqual(b.uop.st, a.uop.permute((1, 0)).st)
-@expect_rangeify_fails
+@unittest.expectedFailure
 def test_become_existing_buf_view_alt(self):
 a = Tensor.empty(4, 4)
 b = a.permute((1, 0)).reshape((8, 2))+0
@@ -2468,7 +2457,7 @@ class TestUOpBecome(unittest.TestCase):
 self.assertEqual(b.uop.st, a.uop.permute((1, 0)).reshape((8, 2)).st)
 # they can also have other base parents that simplified, in that case we just backtrack to the chained mops
-@expect_rangeify_fails
+@unittest.expectedFailure
 def test_become_existing_buf_complex(self):
 a = Tensor.empty(4, 4)
 b = (a.permute((1, 0))+0).reshape((8, 2))+0
@@ -2476,7 +2465,7 @@ class TestUOpBecome(unittest.TestCase):
 self.assertEqual(b.uop.st, a.uop.permute((1, 0)).reshape((8, 2)).st)
 assert b.uop.base.op is Ops.BUFFER
-@expect_rangeify_fails
+@unittest.expectedFailure
 def test_become_multiple_choices(self):
 a = Tensor.empty(16)
 b = (a.reshape(1, 1, 4, 1, 4)+0).reshape(1, 1, 4, 4).shrink(((0, 1), (0, 1), (0, 3), (0, 3)))+0
@@ -2494,13 +2483,8 @@ class TestUOpBecome(unittest.TestCase):
 b.realize()
 assert a.uop.is_realized
 assert a.uop.buffer._base is None
-# b is a subbuffer of a (buffer_view in non rangeify, rangeify just makes a shrink)
-if RANGEIFY:
 assert b.uop.op_in_backward_slice_with_self(Ops.SHRINK)
 assert b.uop.base is a.uop.base
-return
-assert b.uop.op is Ops.BUFFER_VIEW
-assert b.uop.src[0] is a.uop
 def test_setitem_offset(self):
 a = Tensor.full((16,), 0.).contiguous().realize()


@@ -2,7 +2,6 @@ import unittest
 from test.helpers import assert_jit_cache_len
 from tinygrad import Variable, Tensor, TinyJit
-from tinygrad.helpers import RANGEIFY
 import numpy as np
 class TestSymbolicJit(unittest.TestCase):
@@ -27,7 +26,7 @@ class TestSymbolicJit(unittest.TestCase):
 symbolic = jf(a[:, :vi]).numpy()
 expected = f(a[:, :i]).numpy()
 np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-assert_jit_cache_len(jf, 1 if RANGEIFY else 2) # one add and one pad, can be one kernel?
+assert_jit_cache_len(jf, 1)
 def test_add(self):
 def f(a, b): return (a+b).realize()
@@ -80,7 +79,7 @@ class TestSymbolicJit(unittest.TestCase):
 symbolic = jf(q, k[:, :vi], v[:, :vi])[:2, :4, :1, :8].numpy()
 expected = f(q, k[:, :i], v[:, :i]).numpy()
 np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-assert_jit_cache_len(jf, 4 if RANGEIFY else 5)
+assert_jit_cache_len(jf, 4)
 def test_cat_dim0(self):
 def f(a, b): return a.cat(b, dim=0).realize()


@@ -4,7 +4,7 @@ import torch
 import unittest, copy, mmap, random, math, array
 from tinygrad import Tensor, Device, dtypes
 from tinygrad.tensor import _METADATA
-from tinygrad.helpers import getenv, temp, mv_address, RANGEIFY
+from tinygrad.helpers import getenv, temp, mv_address
 from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
 from hypothesis import given, settings, strategies as strat
 from tinygrad.device import is_dtype_supported
@@ -872,13 +872,6 @@ class TestTensorMetadata(unittest.TestCase):
 self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid")
 self.assertTrue(y.grad.uop.metadata[0].backward)
 si = Tensor.schedule(out, x.grad, y.grad)[-1]
-if not RANGEIFY:
-self.assertEqual(len(si.metadata), 4, f"failed with {si.metadata}")
-self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "__mul__", "relu"})
-bw = [m for m in si.metadata if m.backward]
-self.assertEqual(len(bw), 2)
-self.assertEqual(bw[0].name, "sigmoid")
-else:
 self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}")
 self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "relu"})
 bw = [m for m in si.metadata if m.backward]


@@ -1,6 +1,6 @@
 import unittest
 from tinygrad import Tensor
-from tinygrad.helpers import getenv, GlobalCounters, EMULATE, RANGEIFY
+from tinygrad.helpers import getenv, GlobalCounters, EMULATE
 from tinygrad.engine.realize import lower_schedule_item, ProgramSpec, get_program
 from tinygrad.renderer import Estimates
 from tinygrad.codegen import full_rewrite
@@ -51,11 +51,8 @@ class TestMemoryCount(unittest.TestCase):
 a = Tensor.empty(1024, 1, dtype=dtypes.uint8).expand(1024, 1024)
 b = Tensor.empty(1024, 1, dtype=dtypes.uint8).expand(1024, 1024)
 _, mem = get_stats(a+b)
-if RANGEIFY:
 # rangeify is smart!
 self.assertEqual(mem, 1024 + 2*1024) # 2 lil reads + 1 lil write
-else:
-self.assertEqual(mem, 1024*1024 + 2*1024) # 2 lil reads + 1 write
 def test_self_add(self):
 a = Tensor.empty(1024, 1024, dtype=dtypes.uint8)


@@ -1,7 +1,6 @@
 import unittest
 from tinygrad import Tensor
 from tinygrad.uop import Ops
-from tinygrad.helpers import RANGEIFY
 class TestKernelize(unittest.TestCase):
 def test_add_reshaped(self):
@@ -18,8 +17,8 @@ class TestKernelize(unittest.TestCase):
 a1 = a.sum(axis=1)
 a0 = a1.sum(axis=0)
 a0.kernelize()
-self.assertEqual(len([s for s in a0.uop.toposort() if s.op is Ops.KERNEL]), 2 if RANGEIFY else 3)
-self.assertIs(a1.uop.base.op, Ops.REDUCE_AXIS if RANGEIFY else Ops.ASSIGN)
+self.assertEqual(len([s for s in a0.uop.toposort() if s.op is Ops.KERNEL]), 2)
+self.assertIs(a1.uop.base.op, Ops.REDUCE_AXIS)
 # input Tensor and user contiguous kernelize
 self.assertIs(a0.uop.base.op, Ops.ASSIGN)
 self.assertIs(a.uop.base.op, Ops.ASSIGN)


@@ -1,11 +1,11 @@
 import unittest
 import multiprocessing.shared_memory as shared_memory
-from tinygrad.helpers import CI, WIN, RANGEIFY
+from tinygrad.helpers import CI, WIN
 from tinygrad.tensor import Tensor, Device
 import numpy as np
 class TestRawShmBuffer(unittest.TestCase):
-@unittest.skipIf(WIN and CI and RANGEIFY, "only fails with RANGEIFY on CI windows instance")
+@unittest.skipIf(WIN and CI, "only fails on CI windows instance")
 def test_e2e(self):
 t = Tensor.randn(2, 2, 2).realize()


@@ -35,14 +35,14 @@ class TestWinograd(unittest.TestCase):
 def test_forward_kernels(self):
 x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
 out = Tensor.conv2d(x,w)
-self.assertEqual(len(out.schedule()), 2 if RANGEIFY else 4)
+self.assertEqual(len(out.schedule()), 2)
 def test_backward_kernels(self):
 x,w = Tensor.empty(1,4,9,9,requires_grad=True).realize(), Tensor.empty(4,4,3,3,requires_grad=True).realize()
 out = Tensor.conv2d(x,w, padding=1)
 out.mean().backward()
 backward_schedule = Tensor.schedule(x.grad, w.grad)
-self.assertEqual(len(backward_schedule), 4 if RANGEIFY else 9)
+self.assertEqual(len(backward_schedule), 4)
 def test_counters(self):
 IC, OC, X, Y = 4,4,9,9


@@ -6,7 +6,7 @@ from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, Suppor
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
 from tinygrad.dtype import _from_np_dtype, _to_np_dtype
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
-from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray, RANGEIFY, FUSE_ATTENTION
+from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray, FUSE_ATTENTION
 from tinygrad.helpers import suppress_finalizing
 from tinygrad.gradient import compute_gradient
 from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, \
@@ -227,7 +227,7 @@ class Tensor(MathTrait):
 # verify Tensors match the spec
 if __debug__: type_verify(list(big_sink.toposort()), tensor_uop_spec)
-if RANGEIFY and any(isinstance(x._device, tuple) for x in big_sink.toposort()):
+if any(isinstance(x._device, tuple) for x in big_sink.toposort()):
 _apply_map_to_tensors(get_multi_map(big_sink), "Apply Multi Map")
 big_sink = UOp.sink(*flatten([x.uop.src if x.uop.op is Ops.MULTI else [x.uop] for x in (self,)+lst]))