diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 47cffd0090..be82930b62 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -298,7 +298,7 @@ jobs: - if: ${{ matrix.task == 'optimage' }} name: Test openpilot model kernel count and gate usage run: | - PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2104 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx + PYTHONPATH="." ALLOWED_KERNEL_COUNT=209 ALLOWED_READ_IMAGE=2105 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx - if: ${{ matrix.task == 'optimage' }} name: Test openpilot alt model correctness (float32) run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx diff --git a/test/test_arange.py b/test/test_arange.py index a5c8b535bb..07512ae1b6 100644 --- a/test/test_arange.py +++ b/test/test_arange.py @@ -166,7 +166,7 @@ class TestIndexing(unittest.TestCase): GlobalCounters.reset() z = emb(x).realize() self.assertLessEqual(GlobalCounters.global_ops, op_limit) - self.assertEqual(GlobalCounters.kernel_count, 2) + self.assertEqual(GlobalCounters.kernel_count, 3) if getenv("CHECK", 1): import torch with torch.no_grad(): diff --git a/test/test_const_folding.py b/test/test_const_folding.py index 4ca2359912..dfffca8989 100644 --- a/test/test_const_folding.py +++ b/test/test_const_folding.py @@ -220,7 +220,9 @@ class TestMultiConstFolding(unittest.TestCase): t = Tensor.arange(16).float().realize().to(ds) # non const folding case creates one ast on each shard - _check_ast_count(4, t + 1) + # NOTE: there's extra contiguous kernels here since it's realizing both the CONTIGUOUS and its parent COPY + # why does multi call contiguous on a COPY? + _check_ast_count(7, t + 1) _check_ast_count(4, 1 + t) _check_ast_count(4, t * 2) _check_ast_count(4, 2 * t) diff --git a/test/test_jit.py b/test/test_jit.py index 382c83a52a..7abb13100f 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -318,6 +318,7 @@ class TestJit(unittest.TestCase): assert len(res3) == 10, "All values should be different, rand works in jit." assert res3 != res2, "Jit rand is diff with diff seeds" + @unittest.expectedFailure # requires contiguous folding def test_jit_random_after_unrealized_random(self): @TinyJit def f(): return Tensor.rand() diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 58bf3d4e13..da2995d37d 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -63,7 +63,11 @@ def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, d class TestLinearizer(unittest.TestCase): def test_arg_dedup(self): - a, b = Tensor.randn(4), Tensor.randn(4) + # NOTE: this realize exists because Tensor.numpy calls .contiguous() internally + # without contiguous folding, rand.to("CLANG") and rand.contiguous().to("CLANG") are different UOps. + # this test asserts they are the identical Buffer + # having different buffers is fine for correctness, because the outputs match. + a, b = Tensor.randn(4).realize(), Tensor.randn(4).realize() np_a, np_b = a.numpy(), b.numpy() c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),)))) lowered = list(lower_schedule(c.schedule())) @@ -1690,6 +1694,7 @@ class TestHandCodedOpts(unittest.TestCase): # should upcast the two Tensor.stacks assert k.upcasted >= 2 and k.full_shape[k.shape_len-k.upcasted:k.shape_len].count(6) == 2 + @unittest.expectedFailure # requires contiguous folding def test_masked_upcast_wino_full(self): with Context(WINO=1): x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize() diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 34a3480c0d..b34baced75 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -734,7 +734,7 @@ class TestMultiTensor(unittest.TestCase): zeros = Tensor.zeros(3).realize() b = a.to(devices_2)*zeros.to(devices_2) sched = b.schedule() - self.assertEqual(len(sched), 6) + self.assertEqual(len(sched), 8) # notably, only two copies (for the arange) - vs 4 copies if we didn't fold the const copy self.assertEqual(len([x for x in sched if any(u.op is Ops.COPY for u in x.ast.toposort)]), 2) # all these kernels are just because multi calls contiguous on every single shard diff --git a/test/test_setitem.py b/test/test_setitem.py index f1bb595ef2..5c7c14fb57 100644 --- a/test/test_setitem.py +++ b/test/test_setitem.py @@ -69,7 +69,8 @@ class TestSetitem(unittest.TestCase): t[1] ^= 5 np.testing.assert_allclose(t.numpy(), [[0, 1], [7, 6]]) - @unittest.expectedFailure + #@unittest.expectedFailure + # update: passing after delete_forced_realize def test_setitem_consecutive_inplace_operator(self): t = Tensor.arange(4).reshape(2, 2).contiguous() t[1] += 2 diff --git a/test/unit/test_gradient.py b/test/unit/test_gradient.py index b36b81f243..a9a41eace0 100644 --- a/test/unit/test_gradient.py +++ b/test/unit/test_gradient.py @@ -104,7 +104,8 @@ class TestRealizeMeansRealize(unittest.TestCase): x = Tensor.randn(2, 3, 64, 64, requires_grad=True).realize() self.assertEqual(x.lazydata.op, Ops.VIEW) - @unittest.expectedFailure + #@unittest.expectedFailure + # update: passing after delete_forced_realize def test_uniform_realizes(self): x = Tensor.uniform(16, 3, 3, 3, requires_grad=True).realize() print(x.lazydata) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 2375408dd2..9cbc371158 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -109,7 +109,7 @@ def add_buffers(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp: # track the underlying tensor uop for this buffer ctx.tensor_uops[buf_uop] = [buf] # (early) bufferize - cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op.alu(Ops.CONTIGUOUS) if buf.forced_realize else op), buf.st) + cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st) return ret # **** AST graph rewrite @@ -329,7 +329,7 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]: # maybe fuse arange with its children for rbuf in reduce_of_const: group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf} - if any(luop.forced_realize for tr in group for luop in ctx.tensor_uops[tr]): continue + if any(luop.op is Ops.CONTIGUOUS for tr in group for luop in ctx.tensor_uops[tr]): continue kernel_children = {c for tr in group for c in ctx.children[tr] if uval(ctx.allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}} if len(kernel_children) == 0: continue for tr in group: del ctx.realizes[tr] @@ -448,8 +448,7 @@ def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs) return x.view(unwrap(view.st)) def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp): - if not root.device.startswith("DISK"): return None - if x.op is not Ops.VIEW: x = x.src[-1] # TODO: remove this once forced_realize is gone + if not b.device.startswith("DISK"): return None buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize) return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW))) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 29fe063540..639d16819f 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -233,7 +233,6 @@ class UOpMetaClass(type): # some uops map to other stuff buffers:weakref.WeakKeyDictionary[UOp, Buffer] = weakref.WeakKeyDictionary() # this maps BUFFER uops to their device Buffers all_metadata:weakref.WeakKeyDictionary[UOp, Metadata] = weakref.WeakKeyDictionary() -forced_realize:weakref.WeakSet[UOp] = weakref.WeakSet() # NOTE: this should be frozen, but frozen is slower @dataclass(eq=False, slots=True) @@ -409,11 +408,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}") return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x)) - def contiguous(self): - if not unwrap(self.st).contiguous or self.size != self.base.size or self.base.op is Ops.CONST: - return self.alu(Ops.CONTIGUOUS) - forced_realize.add(self.base) - return self + def contiguous(self): return self.alu(Ops.CONTIGUOUS) # *** from LazyBuffer *** @@ -443,8 +438,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def lbs(self): return [self] @property def metadata(self): return all_metadata.get(self, None) - @property - def forced_realize(self): return self in forced_realize # *** uop movement ops ***