From b629a7998decaaad25ba65ce336f4c7dab15993d Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Thu, 26 Sep 2024 08:24:26 +0800
Subject: [PATCH] early assert buffer count limit [run_process_replay] (#6746)

* better error message for buffer count limit [run_process_replay]

* 3.9 needs that

* assert ScheduleItem

* new _test_buf_cnt
---
 test/external/external_model_benchmark.py |  9 ++++-----
 test/test_schedule.py                     | 28 ++++++++++++----------------
 tinygrad/engine/schedule.py               |  6 ++++--
 3 files changed, 20 insertions(+), 23 deletions(-)

diff --git a/test/external/external_model_benchmark.py b/test/external/external_model_benchmark.py
index 6b6c1b0299..762535a381 100644
--- a/test/external/external_model_benchmark.py
+++ b/test/external/external_model_benchmark.py
@@ -1,6 +1,5 @@
 import csv, pathlib, time, numpy as np
 from os import getenv
-from tinygrad.device import CompileError
 import torch
 torch.set_num_threads(1)
 import onnx
@@ -73,10 +72,10 @@ def benchmark_model(m, devices, validate_outs=False):
       for _ in range(3): {k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()}
       benchmark(m, f"tinygrad_{device.lower()}_jit", lambda: {k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()}) # noqa: F821
       del inputs, tinygrad_model, tinygrad_jitted_model
-    except CompileError as e:
-      # METAL fails with buffer count limit
-      if m == "dm" and device == "METAL": return
-      raise e
+    except RuntimeError as e:
+      # TODO: we don't run the dm model on METAL for now
+      if Device.DEFAULT == "METAL": assert "buffer count limit" in str(e)
+      else: raise e
 
   # convert model to torch
   try:
diff --git a/test/test_schedule.py b/test/test_schedule.py
index 86f295acbb..6c7e5036d8 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -10,7 +10,6 @@ from typing import List, Optional, Union, cast
 from tinygrad import nn, dtypes
 from tinygrad.device import Device
 from tinygrad.dtype import DType, PtrDType
-from tinygrad.renderer.cstyle import CStyleLanguage
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
 from tinygrad.tensor import Tensor
@@ -72,18 +71,6 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
   np.testing.assert_allclose(img.grad.numpy(), ref_img.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
   np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
 
-def _test_buf_cnt(cnt:int, buf_max:int, allowed:int):
-  backup_renderer = Device[Device.DEFAULT].renderer
-  r = CStyleLanguage()
-  r.buf_max = buf_max
-  alu = functools.reduce(lambda x,y: x+y, [Tensor.ones((1, 1)).contiguous().realize() for _ in range(cnt-1)])
-  s = alu.schedule()
-  assert len(s) == allowed
-  Device[Device.DEFAULT].renderer = backup_renderer
-  run_schedule(s)
-  expected = functools.reduce(lambda x,y: x+y, [np.ones((1, 1)) for _ in range(cnt-1)])
-  np.testing.assert_equal(alu.numpy(), expected)
-
 class TestSchedule(unittest.TestCase):
   def test_basic_binop_fusion(self):
     a = Tensor.empty(10)
@@ -1326,11 +1313,20 @@ class TestSchedule(unittest.TestCase):
   @unittest.expectedFailure
   def test_conv2d_fused_ast_rewrite_half(self): _test_conv2d(6, FUSE_CONV_BW=1, AST_REWRITE=1, dtype=dtypes.half)
 
-  def test_buf_cnt_at_limit(self): _test_buf_cnt(5, buf_max=5, allowed=1)
+  def _test_buf_cnt(self, cnt:int, allowed:int):
+    if (m:=Device[Device.DEFAULT].renderer.buf_max) is None or m != 32: self.skipTest(f"test needs a buf_max of 32 {Device.DEFAULT}")
+    alu = functools.reduce(lambda x,y: x+y, [Tensor.ones((1, 1)).contiguous().realize() for _ in range(cnt-1)])
+    s = alu.schedule()
+    assert len(s) == allowed
+    run_schedule(s)
+    expected = functools.reduce(lambda x,y: x+y, [np.ones((1, 1)) for _ in range(cnt-1)])
+    np.testing.assert_equal(alu.numpy(), expected)
+
+  def test_buf_cnt_at_limit(self): self._test_buf_cnt(31, allowed=1)
   @unittest.expectedFailure
-  def test_buf_cnt_over_limit(self): _test_buf_cnt(7, buf_max=5, allowed=2)
+  def test_buf_cnt_over_limit(self): self._test_buf_cnt(32, allowed=2)
   @unittest.expectedFailure
-  def test_buf_cnt_over_limit_alt(self): _test_buf_cnt(11, buf_max=5, allowed=3)
+  def test_buf_cnt_over_limit_alt(self): self._test_buf_cnt(63, allowed=3)
 
 class TestIndexing(unittest.TestCase):
   def check_schedule(self, xt:Union[Tensor,List[Tensor]], cnt:int):
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 323334e64b..174fd62a30 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -10,7 +10,7 @@ from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes
 from tinygrad.lazy import LazyBuffer
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.device import Buffer
+from tinygrad.device import Buffer, Device
 from tinygrad.shape.view import View, strides_for_shape
 
 # creation can recurse a lot
@@ -412,7 +412,9 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
       kernel_number += 1
       for out in lsi.outputs: realized_lazybuffer(out, kernel_number)
     for out in lsi.outputs: del out.srcs  # can only schedule once
-    schedule.append(ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.bufs if x.size != 0), lsi.metadata))
+    schedule.append(si:=ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.bufs if x.size != 0), lsi.metadata))
+    if (m:=Device[(device:=si.outputs[0].device)].renderer.buf_max) and len(si.bufs) >= m:
+      raise RuntimeError(f"{si} exceeded the buffer count limit for {device}: {len(si.bufs)} >= {m}")
     for x in graph[lsi]:
       in_degree[x] -= 1
       if in_degree[x] == 0: queue.append(x)
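
---

A minimal repro sketch of the behavior this patch adds (illustrative, not part
of the commit): summing n realized tensors schedules a single kernel with n+1
buffers (n inputs plus the output), so pushing n to the renderer's buf_max
makes schedule creation itself raise. The sketch only relies on what the diff
shows (Tensor.schedule, Device[...].renderer.buf_max, the "buffer count limit"
message); the variable names and the 1x1 tensor shapes are assumptions made
for the example.

    # assumes a backend whose renderer sets buf_max (per the tests above,
    # METAL's limit is 32); where buf_max is None, no limit is asserted and
    # the schedule below builds without error
    import functools
    from tinygrad import Tensor, Device

    if (buf_max:=Device[Device.DEFAULT].renderer.buf_max) is not None:
      # buf_max realized inputs + 1 output buffer puts the kernel one past the limit
      xs = [Tensor.ones((1, 1)).contiguous().realize() for _ in range(buf_max)]
      alu = functools.reduce(lambda x,y: x+y, xs)
      try:
        alu.schedule()  # with this patch, the limit is asserted here, at schedule time
      except RuntimeError as e:
        assert "buffer count limit" in str(e)

Raising at schedule creation, before any kernel is rendered or compiled, is
why external_model_benchmark.py now catches RuntimeError and checks the error
message instead of special-casing the CompileError that METAL used to throw
much later in the pipeline.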