mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
seems logical
This commit is contained in:
@@ -182,7 +182,7 @@ class TestTinygrad(unittest.TestCase):
|
||||
assert Tensor.randn(10, 10).numel() == 100
|
||||
assert Tensor.randn(1,2,5).numel() == 10
|
||||
assert Tensor.randn(1,1,1,1,1,1).numel() == 1
|
||||
# assert Tensor([]).numel() == 0 # TODO: fix empty buffers
|
||||
assert Tensor([]).numel() == 0
|
||||
# assert Tensor.randn(1,0,2,5) == 0 # TODO: fix empty tensors
|
||||
|
||||
def test_element_size(self):
|
||||
|
||||
@@ -25,6 +25,7 @@ class RawBufferCopyIn(RawBuffer):
|
||||
|
||||
@classmethod
|
||||
def fromCPU(cls, x:np.ndarray, **kwargs):
|
||||
if x.size == 0: return EmptyBuffer.fromCPU(x)
|
||||
ret = cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs)
|
||||
ret._copyin(x)
|
||||
return ret
|
||||
@@ -36,7 +37,7 @@ class RawBufferMapped(RawBufferCopyIn):
|
||||
|
||||
# this one is simple enough that i moved it out of the runtimes
|
||||
class RawMallocBuffer(RawBufferMapped):
|
||||
def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int64: ctypes.c_int64}[dtype] * size)())
|
||||
def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.int64: ctypes.c_int64}[dtype] * size)())
|
||||
def _buffer(self): return memoryview(self._buf)
|
||||
|
||||
class RawBufferCopyInOut(RawBufferCopyIn):
|
||||
@@ -49,3 +50,8 @@ class RawBufferCopyInOut(RawBufferCopyIn):
|
||||
|
||||
class RawConst(RawBuffer): # pylint: disable=abstract-method
|
||||
def __repr__(self): return f"const<{self._buf}, {self.dtype}>"
|
||||
|
||||
class EmptyBuffer(RawBuffer):
|
||||
@classmethod
|
||||
def fromCPU(cls, x:np.ndarray): return cls(0, dtypes.from_np(x.dtype))
|
||||
def toCPU(self) -> np.ndarray: return np.empty(0, dtype=self.dtype.np)
|
||||
|
||||
Reference in New Issue
Block a user