diff --git a/test/test_tensor.py b/test/test_tensor.py index ee2eb3ef9d..7e9e997348 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -182,7 +182,7 @@ class TestTinygrad(unittest.TestCase): assert Tensor.randn(10, 10).numel() == 100 assert Tensor.randn(1,2,5).numel() == 10 assert Tensor.randn(1,1,1,1,1,1).numel() == 1 - # assert Tensor([]).numel() == 0 # TODO: fix empty buffers + assert Tensor([]).numel() == 0 # assert Tensor.randn(1,0,2,5) == 0 # TODO: fix empty tensors def test_element_size(self): diff --git a/tinygrad/runtime/lib.py b/tinygrad/runtime/lib.py index d29d81b006..1209315f27 100644 --- a/tinygrad/runtime/lib.py +++ b/tinygrad/runtime/lib.py @@ -25,6 +25,7 @@ class RawBufferCopyIn(RawBuffer): @classmethod def fromCPU(cls, x:np.ndarray, **kwargs): + if x.size == 0: return EmptyBuffer.fromCPU(x) ret = cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs) ret._copyin(x) return ret @@ -36,7 +37,7 @@ class RawBufferMapped(RawBufferCopyIn): # this one is simple enough that i moved it out of the runtimes class RawMallocBuffer(RawBufferMapped): - def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int64: ctypes.c_int64}[dtype] * size)()) + def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.int64: ctypes.c_int64}[dtype] * size)()) def _buffer(self): return memoryview(self._buf) class RawBufferCopyInOut(RawBufferCopyIn): @@ -49,3 +50,8 @@ class RawBufferCopyInOut(RawBufferCopyIn): class RawConst(RawBuffer): # pylint: disable=abstract-method def __repr__(self): return f"const<{self._buf}, {self.dtype}>" + +class EmptyBuffer(RawBuffer): + @classmethod + def fromCPU(cls, x:np.ndarray): return cls(0, dtypes.from_np(x.dtype)) + def toCPU(self) -> np.ndarray: return np.empty(0, dtype=self.dtype.np)