delete forced_realize (#8615)

* delete forced_realize

* put that back

* expectedFailures

* cleaner create_subbuffer

* more comments

---------

Co-authored-by: qazal <qazal.software@gmail.com>
Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
Author: George Hotz
Date: 2025-01-20 09:40:36 -08:00
Committed by: GitHub
Parent: 679b1ad058
Commit: 46a8c5e1e5
10 changed files with 21 additions and 19 deletions

View File

@@ -298,7 +298,7 @@ jobs:
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot model kernel count and gate usage
run: |
PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2104 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
PYTHONPATH="." ALLOWED_KERNEL_COUNT=209 ALLOWED_READ_IMAGE=2105 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot alt model correctness (float32)
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
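For context, the ALLOWED_* variables above are read by examples/openpilot/compile3.py and compared against tinygrad's global kernel counters. A minimal sketch of the kind of gate they imply, using the getenv/GlobalCounters helpers seen elsewhere in this diff (the real check in compile3.py may be structured differently):

from tinygrad.helpers import getenv, GlobalCounters

# hypothetical gate in the spirit of the ALLOWED_KERNEL_COUNT check, not the actual compile3.py code
allowed = getenv("ALLOWED_KERNEL_COUNT", -1)
if allowed != -1:
  assert GlobalCounters.kernel_count <= allowed, f"kernel count {GlobalCounters.kernel_count} > {allowed}"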

View File

@@ -166,7 +166,7 @@ class TestIndexing(unittest.TestCase):
GlobalCounters.reset()
z = emb(x).realize()
self.assertLessEqual(GlobalCounters.global_ops, op_limit)
self.assertEqual(GlobalCounters.kernel_count, 2)
self.assertEqual(GlobalCounters.kernel_count, 3)
if getenv("CHECK", 1):
import torch
with torch.no_grad():
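The counter asserted here is GlobalCounters.kernel_count, so the 2 -> 3 bump means the embedding lookup now schedules one extra kernel. A standalone sketch of the same measurement with made-up sizes (counts can differ by device and shape):

from tinygrad import Tensor
from tinygrad.nn import Embedding
from tinygrad.helpers import GlobalCounters

emb, x = Embedding(10, 4), Tensor([1, 2, 3])  # made-up sizes, not the test's
GlobalCounters.reset()
emb(x).realize()
print(GlobalCounters.kernel_count)  # the quantity this test pins to 3 after the change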

View File

@@ -220,7 +220,9 @@ class TestMultiConstFolding(unittest.TestCase):
t = Tensor.arange(16).float().realize().to(ds)
# non const folding case creates one ast on each shard
_check_ast_count(4, t + 1)
# NOTE: there's extra contiguous kernels here since it's realizing both the CONTIGUOUS and its parent COPY
# why does multi call contiguous on a COPY?
_check_ast_count(7, t + 1)
_check_ast_count(4, 1 + t)
_check_ast_count(4, t * 2)
_check_ast_count(4, 2 * t)
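A sketch of what _check_ast_count is asserting here, assuming ds is a 4-way shard and that the helper counts only compute kernels (SINK asts) in the schedule; its real implementation lives in the test file:

from tinygrad import Tensor, Device
from tinygrad.ops import Ops

ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))  # assumed 4-way shard, matching the counts above
t = Tensor.arange(16).float().realize().to(ds)
sched = (t + 1).schedule()
# count compute kernels only, not the COPYs that shard t across devices
print(len([si for si in sched if si.ast.op is Ops.SINK]))  # 7 per the updated assert: 4 adds plus the extra CONTIGUOUS kernels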

View File

@@ -318,6 +318,7 @@ class TestJit(unittest.TestCase):
assert len(res3) == 10, "All values should be different, rand works in jit."
assert res3 != res2, "Jit rand is diff with diff seeds"
@unittest.expectedFailure # requires contiguous folding
def test_jit_random_after_unrealized_random(self):
@TinyJit
def f(): return Tensor.rand()

View File

@@ -63,7 +63,11 @@ def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, d
class TestLinearizer(unittest.TestCase):
def test_arg_dedup(self):
a, b = Tensor.randn(4), Tensor.randn(4)
# NOTE: this realize exists because Tensor.numpy calls .contiguous() internally
# without contiguous folding, rand.to("CLANG") and rand.contiguous().to("CLANG") are different UOps.
# this test asserts they are the identical Buffer
# having different buffers is fine for correctness, because the outputs match.
a, b = Tensor.randn(4).realize(), Tensor.randn(4).realize()
np_a, np_b = a.numpy(), b.numpy()
c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),))))
lowered = list(lower_schedule(c.schedule()))
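At the user level, the dedup being tested: a and b are each read twice through different shrinks, but the lowered kernel should carry each input buffer once. A rough check through the schedule, assuming ScheduleItem exposes its buffers as .bufs (the test itself inspects the lowered kernel via lower_schedule):

from tinygrad import Tensor

a, b = Tensor.randn(4).realize(), Tensor.randn(4).realize()
c = (a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),)))
si = c.schedule()[-1]
print(len(si.bufs))  # expect 3: the output plus a and b once each, despite the double reads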
@@ -1690,6 +1694,7 @@ class TestHandCodedOpts(unittest.TestCase):
# should upcast the two Tensor.stacks
assert k.upcasted >= 2 and k.full_shape[k.shape_len-k.upcasted:k.shape_len].count(6) == 2
@unittest.expectedFailure # requires contiguous folding
def test_masked_upcast_wino_full(self):
with Context(WINO=1):
x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize()

View File

@@ -734,7 +734,7 @@ class TestMultiTensor(unittest.TestCase):
zeros = Tensor.zeros(3).realize()
b = a.to(devices_2)*zeros.to(devices_2)
sched = b.schedule()
self.assertEqual(len(sched), 6)
self.assertEqual(len(sched), 8)
# notably, only two copies (for the arange) - vs 4 copies if we didn't fold the const copy
self.assertEqual(len([x for x in sched if any(u.op is Ops.COPY for u in x.ast.toposort)]), 2)
# all these kernels are just because multi calls contiguous on every single shard
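The copy folding referenced in the comment: sharding a realized buffer needs a real COPY per device, while a const can be rematerialized on each shard with no copy at all. A minimal sketch of the same COPY count, with assumed device names and the counting expression mirroring the assert above:

from tinygrad import Tensor, Device
from tinygrad.ops import Ops

devices_2 = tuple(f"{Device.DEFAULT}:{i}" for i in (1, 2))  # assumed pair of shard devices

def n_copy_kernels(t: Tensor) -> int:
  return len([si for si in t.schedule() if any(u.op is Ops.COPY for u in si.ast.toposort)])

print(n_copy_kernels(Tensor.arange(3).realize().to(devices_2)))  # 2: one real copy per shard
print(n_copy_kernels(Tensor.zeros(3).to(devices_2)))             # 0 if the const copy folds as described above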

View File

@@ -69,7 +69,8 @@ class TestSetitem(unittest.TestCase):
t[1] ^= 5
np.testing.assert_allclose(t.numpy(), [[0, 1], [7, 6]])
@unittest.expectedFailure
#@unittest.expectedFailure
# update: passing after delete_forced_realize
def test_setitem_consecutive_inplace_operator(self):
t = Tensor.arange(4).reshape(2, 2).contiguous()
t[1] += 2

View File

@@ -104,7 +104,8 @@ class TestRealizeMeansRealize(unittest.TestCase):
x = Tensor.randn(2, 3, 64, 64, requires_grad=True).realize()
self.assertEqual(x.lazydata.op, Ops.VIEW)
@unittest.expectedFailure
#@unittest.expectedFailure
# update: passing after delete_forced_realize
def test_uniform_realizes(self):
x = Tensor.uniform(16, 3, 3, 3, requires_grad=True).realize()
print(x.lazydata)

View File

@@ -109,7 +109,7 @@ def add_buffers(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
# track the underlying tensor uop for this buffer
ctx.tensor_uops[buf_uop] = [buf]
# (early) bufferize
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op.alu(Ops.CONTIGUOUS) if buf.forced_realize else op), buf.st)
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st)
return ret
# **** AST graph rewrite
@@ -329,7 +329,7 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]:
# maybe fuse arange with its children
for rbuf in reduce_of_const:
group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf}
if any(luop.forced_realize for tr in group for luop in ctx.tensor_uops[tr]): continue
if any(luop.op is Ops.CONTIGUOUS for tr in group for luop in ctx.tensor_uops[tr]): continue
kernel_children = {c for tr in group for c in ctx.children[tr] if uval(ctx.allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}}
if len(kernel_children) == 0: continue
for tr in group: del ctx.realizes[tr]
@@ -448,8 +448,7 @@ def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs)
return x.view(unwrap(view.st))
def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
if not root.device.startswith("DISK"): return None
if x.op is not Ops.VIEW: x = x.src[-1] # TODO: remove this once forced_realize is gone
if not b.device.startswith("DISK"): return None
buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
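The pointer arithmetic in create_subbuffer, offset*itemsize bytes into the parent buffer, as a plain-Python analogy (this is not tinygrad API, just an illustration of what Buffer.view shares instead of copies):

import numpy as np

parent = np.arange(16, dtype=np.float32)   # stands in for the parent DISK buffer
raw = memoryview(parent).cast("B")         # its raw bytes
offset_elems, size_elems = 4, 8
# a view starting offset*itemsize bytes in, sharing storage with the parent rather than copying it
sub = raw[offset_elems*parent.itemsize:(offset_elems + size_elems)*parent.itemsize]
assert np.frombuffer(sub, dtype=np.float32)[0] == parent[4]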

View File

@@ -233,7 +233,6 @@ class UOpMetaClass(type):
# some uops map to other stuff
buffers:weakref.WeakKeyDictionary[UOp, Buffer] = weakref.WeakKeyDictionary() # this maps BUFFER uops to their device Buffers
all_metadata:weakref.WeakKeyDictionary[UOp, Metadata] = weakref.WeakKeyDictionary()
forced_realize:weakref.WeakSet[UOp] = weakref.WeakSet()
# NOTE: this should be frozen, but frozen is slower
@dataclass(eq=False, slots=True)
@@ -409,11 +408,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}")
return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
def contiguous(self):
if not unwrap(self.st).contiguous or self.size != self.base.size or self.base.op is Ops.CONST:
return self.alu(Ops.CONTIGUOUS)
forced_realize.add(self.base)
return self
def contiguous(self): return self.alu(Ops.CONTIGUOUS)
# *** from LazyBuffer ***
@@ -443,8 +438,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def lbs(self): return [self]
@property
def metadata(self): return all_metadata.get(self, None)
@property
def forced_realize(self): return self in forced_realize
# *** uop movement ops ***
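Net effect of the ops.py change: UOp.contiguous no longer special-cases already-contiguous bases into the forced_realize WeakSet, the requirement is always an explicit op in the graph. A quick user-level check, assuming Tensor.contiguous forwards to UOp.contiguous as the rest of this diff suggests:

from tinygrad import Tensor
from tinygrad.ops import Ops

t = Tensor.arange(4).contiguous()
# previously an already-contiguous base could be flagged only in the side WeakSet;
# now the CONTIGUOUS op always appears in the uop graph itself
assert any(u.op is Ops.CONTIGUOUS for u in t.lazydata.toposort)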