From 67f4e03724f9c9291e660d38f97dfac7eb8772e3 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Fri, 1 Dec 2023 18:29:06 -0500
Subject: [PATCH] rewrite 0 size loadop into a CONST (#2556)

* rewrite 0 size loadop into a CONST

* check alloc size

* EMPTY is better

* Revert "EMPTY is better"

This reverts commit 574fe0f9ed28f1b97da5a81afdfd2cd5d9a94ff9.

* no ast is created

* fix test
---
 test/models/test_real_world.py |  5 +++--
 test/test_schedule.py          |  5 +++++
 tinygrad/device.py             |  4 +++-
 tinygrad/lazy.py               |  5 ++++-
 tinygrad/realize.py            | 18 ++++++++----------
 tinygrad/runtime/ops_cuda.py   |  4 +---
 tinygrad/runtime/ops_gpu.py    |  1 -
 tinygrad/runtime/ops_hip.py    |  1 -
 tinygrad/runtime/ops_metal.py  |  1 -
 tinygrad/tensor.py             |  1 +
 10 files changed, 25 insertions(+), 20 deletions(-)

diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py
index f2e9a9aa9b..5afb06bcac 100644
--- a/test/models/test_real_world.py
+++ b/test/models/test_real_world.py
@@ -6,6 +6,7 @@ from tinygrad.nn.state import get_parameters
 from tinygrad.jit import TinyJit
 from tinygrad import Device, GlobalCounters
 from tinygrad.helpers import CI, dtypes
+from tinygrad.shape.symbolic import Variable
 from test.helpers import derandomize_model
 
 from examples.gpt2 import Transformer as GPT2Transformer, MODEL_PARAMS as GPT2_MODEL_PARAMS
@@ -67,8 +68,8 @@ class TestRealWorld(unittest.TestCase):
     model = GPT2Transformer(**(args_tiny if CI else GPT2_MODEL_PARAMS["gpt2-medium"]))
     derandomize_model(model)
     @TinyJit
-    def test(t): return model(t, 0).realize()
-    helper_test("test_gpt2", lambda: (Tensor([[1,]]),), test, 0.21 if CI else 0.9, 140 if CI else 396, all_jitted=True)
+    def test(t, v): return model(t, v).realize()
+    helper_test("test_gpt2", lambda: (Tensor([[1,]]),Variable("pos", 1, 100).bind(1)), test, 0.21 if CI else 0.9, 180 if CI else 516, all_jitted=True)
 
   @unittest.skipUnless((Device.DEFAULT not in ["LLVM", "CLANG", "CPU"] or not CI), "needs JIT, too long on CI LLVM and CLANG")
   def test_train_cifar(self):
diff --git a/test/test_schedule.py b/test/test_schedule.py
index 643e6285d9..952f996749 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -332,5 +332,10 @@ class TestSchedule(unittest.TestCase):
     out = x ** Tensor(2)
     check_schedule(out, 1)
 
+  def test_zero_size(self):
+    x = Tensor.rand(2, 3, 0)
+    out = x + 1
+    check_schedule(out, 0, filter_loadops=False)
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)
diff --git a/tinygrad/device.py b/tinygrad/device.py
index 0e200d6892..75aa016046 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -83,7 +83,9 @@ class Buffer:
 
 # TODO: size, dest, src are the same type. can we enforce this?
 class Allocator:
-  def alloc(self, size:int, dtype:DType): return self._alloc(size, dtype)
+  def alloc(self, size:int, dtype:DType):
+    assert size > 0, f"alloc size must be positive, got {size}"
+    return self._alloc(size, dtype)
   def _alloc(self, size:int, dtype:DType): raise NotImplementedError("need alloc")
   def free(self, opaque, size:int, dtype:DType): self._free(opaque) # if you are returning a Python object, you don't need a free
   def _free(self, opaque): pass
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index e7fd20307f..0c9911572c 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -82,6 +82,9 @@ def vars_from_ast(ast:LazyOp) -> List[Variable]: return sorted(set.union(*[x.arg
 
 lazycache: WeakValueDictionary = WeakValueDictionary()
 def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, base:Optional[LazyBuffer]=None):
+  # rewrite 0 size into a CONST
+  if 0 in st.shape: return LazyBuffer(device, ShapeTracker.from_shape(st.shape), LoadOps, LazyOp(LoadOps.CONST, tuple(), 0.0), dtype)
+
   # fromcpu aren't cached
   if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}):
     return LazyBuffer(device, st, optype, op, dtype, base=base)
@@ -183,7 +186,7 @@ class LazyBuffer:
 
   @staticmethod
   def loadop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Optional[LazyBuffer]=None) -> LazyBuffer:
-    return create_lazybuffer(device, ShapeTracker.from_shape(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype)
+    return create_lazybuffer(device, ShapeTracker.from_shape(shape), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype)
 
   # create a constant with the shape and dtype of self
   def const(self, val:Union[float, int]) -> LazyBuffer:
diff --git a/tinygrad/realize.py b/tinygrad/realize.py
index 6989d6e7a0..6afbd1f204 100644
--- a/tinygrad/realize.py
+++ b/tinygrad/realize.py
@@ -27,16 +27,14 @@ def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
     # we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
     si.out.realized = si.out.output_buffer if si.out.output_buffer is not None else \
       Buffer(si.out.device, prod((s if isinstance(s, int) else s.max for s in si.out.shape)), si.out.dtype)
-    # TODO: size 0 should be removed from the schedule
-    if si.out.realized.size != 0:
-      if si.ast.op in LoadOps:
-        if DEBUG >= 2: print(f"*** {si.ast.op:>15s} {f'{si.out.device} <- {si.inputs[0].device}' if si.ast.op is LoadOps.FROM else si.out.device:25s} sz {si.out.realized.size:5d} shape {si.out.shape} dtype {si.out.dtype} arg {si.ast.arg}")
-        # confirm the LoadOps are contiguous and in order
-        for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.LOAD and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
-        kwargs = {"arg": si.ast.arg} if si.ast.arg is not None else {}
-        LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)](si.out.realized, *[x.realized for x in si.inputs], **kwargs)
-      else:
-        Device[si.out.device].get_runner(si.ast).exec([si.out.realized] + [x.realized for x in si.inputs], si.var_vals)
+    if si.ast.op in LoadOps:
+      if DEBUG >= 2: print(f"*** {si.ast.op:>15s} {f'{si.out.device} <- {si.inputs[0].device}' if si.ast.op is LoadOps.FROM else si.out.device:25s} sz {si.out.realized.size:5d} shape {si.out.shape} dtype {si.out.dtype} arg {si.ast.arg}")
+      # confirm the LoadOps are contiguous and in order
+      for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.LOAD and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
+      kwargs = {"arg": si.ast.arg} if si.ast.arg is not None else {}
+      LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)](si.out.realized, *[x.realized for x in si.inputs], **kwargs)
+    else:
+      Device[si.out.device].get_runner(si.ast).exec([si.out.realized] + [x.realized for x in si.inputs], si.var_vals)
     del si.out.op
     for v in si.out.views: del v.op
     #assert si.out.realized and isinstance(si.out.realized, Device[si.out.device].buffer), f"device mismatch on realized got {type(si.out.realized)} expected {si.out.device}"
diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py
index 05d5b70aa6..5a08eca0e1 100644
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -46,9 +46,7 @@ class CUDAProgram:
     return cu_time_execution(lambda: check(cuda.cuLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, c_kernel_input_config)), enable=wait)
 
 class CUDAAllocator(LRUAllocator):
-  def _alloc(self, size, dtype):
-    if size == 0: return None
-    return init_c_var(cuda.CUdeviceptr(), lambda x: check(cuda.cuMemAlloc_v2(ctypes.byref(x), size * dtype.itemsize)))
+  def _alloc(self, size, dtype): return init_c_var(cuda.CUdeviceptr(), lambda x: check(cuda.cuMemAlloc_v2(ctypes.byref(x), size * dtype.itemsize)))
   def _free(self, opaque): check(cuda.cuMemFree_v2(opaque))
   def copyin(self, dest, src:memoryview): check(cuda.cuMemcpyHtoD_v2(dest, from_mv(src), len(src), None))
   def copyout(self, dest:memoryview, src): check(cuda.cuMemcpyDtoH_v2(from_mv(dest), src, len(dest)))
diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py
index 1cc2d0338b..6b6d832edd 100644
--- a/tinygrad/runtime/ops_gpu.py
+++ b/tinygrad/runtime/ops_gpu.py
@@ -59,7 +59,6 @@ class CLAllocator(LRUAllocator):
     self.device = device
     super().__init__()
   def _alloc(self, size:int, dtype:DType):
-    if size == 0: return None
     if isinstance(dtype, ImageDType):
       # NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize
       assert size == prod(dtype.shape), f"image size mismatch {size} != {dtype.shape}"
diff --git a/tinygrad/runtime/ops_hip.py b/tinygrad/runtime/ops_hip.py
index 6c54f734f1..b5aca40d51 100644
--- a/tinygrad/runtime/ops_hip.py
+++ b/tinygrad/runtime/ops_hip.py
@@ -49,7 +49,6 @@ class HIPAllocator(LRUAllocator):
     self.device = device
     super().__init__()
   def _alloc(self, size: int, dtype: DType):
-    if size == 0: return None
     check(hip.hipSetDevice(self.device))
     return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipMalloc(ctypes.byref(x), size * dtype.itemsize)))
   def _free(self, opaque:T): check(hip.hipFree(opaque))
diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py
index 3fb4884d05..d6dc530469 100644
--- a/tinygrad/runtime/ops_metal.py
+++ b/tinygrad/runtime/ops_metal.py
@@ -52,7 +52,6 @@ class MetalAllocator(LRUAllocator):
    self.device:MetalDevice = device
    super().__init__()
  def _alloc(self, size:int, dtype:DType) -> Any:
-    if size == 0: return None
    ret = self.device.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared)
    if ret is None: raise MemoryError(f"Metal OOM while allocating {size=} {dtype=}")
    return ret
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 7e1214eb53..702c6ede65 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -123,6 +123,7 @@ class Tensor:
   def numpy(self) -> np.ndarray:
     assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}"
     assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}"
+    if 0 in self.shape: return np.zeros(self.shape, dtype=self.dtype.np)
     return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized.toCPU().astype(self.dtype.np, copy=True).reshape(self.shape)
   def item(self) -> Union[float, int]: return self.numpy().item()
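
A rough usage sketch of the behavior this change targets, included for context and not part of the patch itself; it assumes only the tinygrad API already visible in the diff (Tensor.rand, realize, numpy):

  from tinygrad.tensor import Tensor

  # a zero-size tensor: 2*3*0 == 0 elements
  x = Tensor.rand(2, 3, 0)
  # the lazybuffer is rewritten into a zero-size CONST loadop, so realizing it schedules
  # no kernels and never reaches an allocator (which now asserts size > 0)
  out = (x + 1).realize()
  # numpy() short-circuits zero-size shapes to an empty ndarray of the right shape and dtype
  assert out.numpy().shape == (2, 3, 0) and out.numpy().size == 0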