Revert "bug in metal: offset is stored as uint32, overflow (#15129)" (#15136)

This reverts commit 9c58db16fa.
This commit is contained in:
chenyu
2026-03-04 16:54:42 -05:00
committed by GitHub
parent 9c58db16fa
commit 34594bcaaf
2 changed files with 1 additions and 32 deletions

View File

@@ -1,26 +0,0 @@
import unittest
from unittest.mock import MagicMock
from tinygrad import Device
from tinygrad.engine.realize import CompiledRunner
def _ei(*offsets):
ei = MagicMock()
ei.prg = MagicMock(spec=CompiledRunner)
ei.bufs = [None if o is None else MagicMock(**{"_buf.offset": o}) for o in offsets]
return ei
@unittest.skipUnless(Device.DEFAULT == "METAL", "Metal device required to run")
class TestMetalGraph(unittest.TestCase):
def setUp(self):
from tinygrad.runtime.graph.metal import MetalGraph
self.MetalGraph = MetalGraph
self.dev = Device[Device.DEFAULT]
def test_supports_exec_item_normal_offset(self):
assert self.MetalGraph.supports_exec_item([self.dev], _ei(0, 100, 0xFFFFFFFF)) is True
def test_supports_exec_item_overflow_offset(self):
assert self.MetalGraph.supports_exec_item([], _ei(0, 0x100000000)) is False
if __name__ == "__main__":
unittest.main()

View File

@@ -13,6 +13,7 @@ class MetalGraph(GraphRunner):
def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int],
orig_valid_positions: dict[int, set[int]]|None = None):
super().__init__(jit_cache, input_buffers, var_vals, orig_valid_positions)
if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
# create metal batch exec
icb_descriptor = metal.MTLIndirectCommandBufferDescriptor.new()
@@ -108,9 +109,3 @@ class MetalGraph(GraphRunner):
if PROFILE and self.command_buffer is not None:
wait_check(self.command_buffer)
self.collect_timestamps()
@staticmethod
def supports_exec_item(devs, ei:ExecItem) -> bool:
# Metal ICB replay encodes offsets as uint32; reject if any buffer offset exceeds 32-bit range.
if any(b is not None and b._buf.offset > 0xFFFFFFFF for b in ei.bufs): return False
return GraphRunner.supports_exec_item(devs, ei)