mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Disktensors! (#819)
* make empty a real thing * start ops_disk * disk tensor works * interpreted cleanup * slice write to disk * preprocess imagenet * fix custom function
This commit is contained in:
@@ -31,7 +31,7 @@ CL = _CL()
|
||||
|
||||
# TODO: merge CLImage in here
|
||||
class CLBuffer(RawBufferCopyInOut):
|
||||
def __init__(self, size, dtype, device=0):
|
||||
def __init__(self, size, dtype, device='0'):
|
||||
if isinstance(dtype, ImageDType):
|
||||
fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize])
|
||||
buf = cl.Image(CL.cl_ctx, cl.mem_flags.READ_WRITE, fmt, shape=(dtype.shape[1], dtype.shape[0]))
|
||||
@@ -39,7 +39,7 @@ class CLBuffer(RawBufferCopyInOut):
|
||||
# NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize
|
||||
else:
|
||||
buf = cl.Buffer(CL.cl_ctx, cl.mem_flags.READ_WRITE, size * dtype.itemsize)
|
||||
setattr(buf, 'device', device) # device is tracked on the underlying buffer
|
||||
setattr(buf, 'device', int(device)) # device is tracked on the underlying buffer
|
||||
super().__init__(size, dtype, buf)
|
||||
def _copyin(self, x:np.ndarray):
|
||||
assert not self.dtype.name.startswith("image"), f"can't copyin images {self.dtype}"
|
||||
|
||||
Reference in New Issue
Block a user