Disktensors! (#819)

* make empty a real thing

* start ops_disk

* disk tensor works

* interpreted cleanup

* slice write to disk

* preprocess imagenet

* fix custom function
This commit is contained in:
George Hotz
2023-05-28 15:40:37 -07:00
committed by GitHub
parent 6d49925a26
commit 59f9bcd4a4
13 changed files with 151 additions and 26 deletions

2
.gitignore vendored
View File

@@ -24,3 +24,5 @@ datasets/cifar-10-python.tar.gz
datasets/librispeech/
datasets/imagenet/
datasets/squad/
datasets/img_align_celeba*
datasets/open-images-v6-mlperf

View File

@@ -0,0 +1,22 @@
from tinygrad.tensor import Tensor
from datasets.imagenet import iterate, get_val_files
if __name__ == "__main__":
  # Copy the first `sz` samples produced by `iterate` into two disk-backed
  # tensors (/tmp/imagenet_x and /tmp/imagenet_y), one batch at a time.
  #sz = len(get_val_files())
  sz = 32*100
  X,Y = None, None

  idx = 0  # number of samples written so far
  for x,y in iterate(shuffle=False):
    print(x.shape, y.shape)
    assert x.shape[0] == y.shape[0]
    bs = x.shape[0]
    if X is None:
      # the outputs are created on the first batch because the per-sample
      # shapes are only known once a batch has been seen
      # TODO: need uint8 support
      X = Tensor.empty(sz, *x.shape[1:], device="disk:/tmp/imagenet_x")
      Y = Tensor.empty(sz, *y.shape[1:], device="disk:/tmp/imagenet_y")
      print(X.shape, Y.shape)
    # never write past the end of the outputs: when sz is not a multiple of
    # the batch size only the leading rows of the last batch still fit
    n = min(bs, sz-idx)
    if n < bs: x, y = x[:n], y[:n]
    X[idx:idx+n].assign(x)
    Y[idx:idx+n].assign(y)
    idx += n
    if idx >= sz: break

View File

@@ -24,7 +24,7 @@ def atan2_gpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
return ret.realized
def atan2_cpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
return Device[ret.device].buffer(np.arctan2(a.realized._buf, b.realized._buf))
return Device[ret.device].from_underlying(np.arctan2(a.realized._buf, b.realized._buf))
# *** second, we write the ATan2 mlop ***
# NOTE: The derivative of atan2 doesn't need a custom op! https://www.liquisearch.com/atan2/derivative

View File

@@ -56,6 +56,8 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
class TestOps(unittest.TestCase):
def test_zeros(self):
helper_test_op([], lambda: torch.zeros(45,65), lambda: Tensor.zeros(45,65), forward_only=True)
def test_empty_0(self):
helper_test_op([], lambda: torch.empty(45,65)*0/0, lambda: Tensor.empty(45,65)*0/0, forward_only=True)
def test_ones(self):
helper_test_op([], lambda: torch.ones(45,65), lambda: Tensor.ones(45,65), forward_only=True)
def test_eye(self):

View File

@@ -0,0 +1,56 @@
import pathlib
import unittest
import numpy as np
from tinygrad.tensor import Tensor
class TestDiskTensor(unittest.TestCase):
  # Tests for tensors that live on the "disk:<path>" device.
  # Every test removes its backing file first so it starts from a clean state.

  def test_empty(self):
    # creating an empty disk tensor must not raise
    pathlib.Path("/tmp/dt1").unlink(missing_ok=True)
    Tensor.empty(100, 100, device="disk:/tmp/dt1")

  def test_write_ones(self):
    # move a CPU tensor of ones to disk, then check both the raw file and a reload
    pathlib.Path("/tmp/dt2").unlink(missing_ok=True)

    out = Tensor.ones(10, 10, device="CPU")
    outdisk = out.to("disk:/tmp/dt2")
    print(outdisk)
    outdisk.realize()
    del out, outdisk

    # test file: 100 elements, each the little-endian float32 encoding of 1.0
    with open("/tmp/dt2", "rb") as f:
      assert f.read() == b"\x00\x00\x80\x3F" * 100

    # test load alt: an empty tensor on the same path sees the existing contents
    reloaded = Tensor.empty(10, 10, device="disk:/tmp/dt2")
    out = reloaded.numpy()
    assert np.all(out == 1.)

  def test_slice(self):
    # slice a single element out of a 1-D disk tensor
    pathlib.Path("/tmp/dt3").unlink(missing_ok=True)
    Tensor.arange(10, device="disk:/tmp/dt3").realize()

    slice_me = Tensor.empty(10, device="disk:/tmp/dt3")
    print(slice_me)
    is_3 = slice_me[3:4].cpu()
    assert is_3.numpy()[0] == 3

  def test_slice_2d(self):
    # index one row out of a 2-D view of the file
    pathlib.Path("/tmp/dt5").unlink(missing_ok=True)
    Tensor.arange(100, device="CPU").to("disk:/tmp/dt5").realize()
    slice_me = Tensor.empty(10, 10, device="disk:/tmp/dt5")
    tst = slice_me[1].numpy()
    print(tst)
    np.testing.assert_allclose(tst, np.arange(10, 20))

  def test_assign_slice(self):
    # assigning to a slice must change only that slice of the disk tensor
    pathlib.Path("/tmp/dt4").unlink(missing_ok=True)
    cc = Tensor.arange(10, device="CPU").to("disk:/tmp/dt4").realize()

    #cc.assign(np.ones(10)).realize()
    print(cc[3:5].numpy())
    cc[3:5].assign([13, 12]).realize()
    print(cc.numpy())
    # without this check the test could never fail on wrong data
    np.testing.assert_allclose(cc.numpy(), [0, 1, 2, 13, 12, 5, 6, 7, 8, 9])

if __name__ == "__main__":
  unittest.main()

View File

@@ -8,6 +8,7 @@ def multidevice_test(fxn):
exclude_devices = getenv("EXCLUDE_DEVICES", "").split(",")
def ret(self):
for device in Device._buffers:
if device == "DISK": continue
print(device)
if device in exclude_devices:
print(f"WARNING: {device} test is excluded")

View File

@@ -102,7 +102,7 @@ class LazyBuffer:
if GRAPH >= 3: log_op(self, self.op, phantom=True)
def __repr__(self): return f"<LB {self.shape} {self.dtype} op:{self.op.op if self.realized is None else self.realized} st:{self.st}>"
def _device_extra_args(self) -> Dict[str, int]: return {"device": int(self.device.split(":")[1])} if ":" in self.device else {}
def _device_extra_args(self) -> Dict[str, str]: return {"device": self.device.split(":")[1]} if ":" in self.device else {}
def realize(self:LazyBuffer) -> LazyBuffer:
if self.realized is None:
@@ -113,6 +113,8 @@ class LazyBuffer:
else:
if DEBUG >= 4: print(f"copying {self.op.arg.shape}:{dtypes.from_np(self.op.arg.dtype)} -> {self.device}")
self.realized = Device[self.device].buffer.fromCPU(self.op.arg(), **self._device_extra_args())
elif self.op.op == LoadOps.EMPTY:
self.realized = Device[self.device].buffer(prod(self.shape), self.dtype, **self._device_extra_args())
elif self.op.op == LoadOps.CONTIGUOUS:
realized = self.op.src[0].realize().realized
if self.op.src[0].st.contiguous and not isinstance(realized, RawConst) and realized.size == prod(self.shape):
@@ -161,6 +163,10 @@ class LazyBuffer:
def fromCPU(x:LazyNumpyArray, device) -> LazyBuffer:
return create_lazybuffer(device, x.shape, LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), x), dtypes.from_np(x.dtype))
@staticmethod
def empty(shape, dtype, device) -> LazyBuffer:
return create_lazybuffer(device, shape, LoadOps, LazyOp(LoadOps.EMPTY, tuple()), dtype)
# create a constant with the shape and dtype of self
def const_like(self, val) -> LazyBuffer:
return create_lazybuffer(self.device, (1,), LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), LazyNumpyArray([val], (1,), self.dtype.np)), self.dtype) \

View File

@@ -12,7 +12,7 @@ class UnaryOps(Enum): NOOP = auto(); EXP = auto(); LOG = auto(); CAST = auto();
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto() # noqa: E702
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class FusedOps(Enum): MULACC = auto() # noqa: E702
class LoadOps(Enum): FROMCPU = auto(); CONTIGUOUS = auto(); TOCPU = auto(); CUSTOM = auto() # noqa: E702
class LoadOps(Enum): EMPTY = auto(); FROMCPU = auto(); CONTIGUOUS = auto(); TOCPU = auto(); CUSTOM = auto() # noqa: E702
Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, FusedOps]
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[FusedOps]]
@@ -34,24 +34,25 @@ def map_buffers(real_srcs:Dict[Any, Any], x:Any) -> LazyOp:
# **************** for Interpreted Buffers ****************
class Interpreted:
def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], from_lazybuffer=lambda x: x.realized, to_underlying=lambda x: x._buf):
def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], from_lazybuffer=lambda x: x.realized, to_underlying=lambda x: x._buf, from_underlying=None):
self.buffer = buffer
self.fxn_for_op = fxn_for_op
self.from_lazybuffer = from_lazybuffer
self.from_underlying = buffer if from_underlying is None else from_underlying
self.to_underlying = to_underlying
self.synchronize = lambda: None
self.codegen = None
def exec_ast(self, ast:LazyOp, output=None, context=None):
def exec_ast(self, ast:LazyOp, output=None, context=None, **kwargs):
if FusedOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
ast = LazyOp(FusedOps.MULACC, ast.src[0].src, ast.arg)
created_context = context is None
if context is None: context = dict()
if not created_context and ast in context: return context[ast]
srcs = [self.exec_ast(x, context=context) if isinstance(x, LazyOp) else self.from_lazybuffer(x) for x in ast.src]
srcs = [self.exec_ast(x, context=context, **kwargs) if isinstance(x, LazyOp) else self.from_lazybuffer(x) for x in ast.src]
if DEBUG >= 3: st = time.perf_counter()
ret = self.buffer(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
if DEBUG >= 3: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB {(time.perf_counter()-st)*1e3:7.2f} ms op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape):30s} in({len(srcs)}):", list(set(x._buf.shape for x in srcs)), ast.arg if ast.arg is not None else "")
ret = self.from_underlying(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
if DEBUG >= 3: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB {(time.perf_counter()-st)*1e3:7.2f} ms op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape) if hasattr(ret._buf, 'shape') else str(len(ret._buf)):30s} in({len(srcs)}):", list(set(x._buf.shape if hasattr(x._buf, 'shape') else len(x._buf) for x in srcs)), ast.arg if ast.arg is not None else "")
if not created_context: context[ast] = ret
if output is not None and output.output_buffer is not None:
assert output.output_buffer.size == ret.size, output.output_buffer.dtype == ret.dtype
@@ -151,7 +152,7 @@ class Compiled:
k = self.codegen(ast, output)
# this is the default now
if getenv("ENABLE_METHOD_CACHE", 1):
if hasattr(k, 'key') and getenv("ENABLE_METHOD_CACHE", 1):
if k.key not in self.method_cache: self.method_cache[k.key] = k.codegen().build(self.runtime)
elif DEBUG >= 5: print(f"method cache hit : {k.key}")
prg = self.method_cache[k.key]

View File

@@ -1,7 +1,7 @@
import numpy as np
import operator
from typing import Callable, Dict, Tuple
from tinygrad.helpers import dtypes
from typing import Callable, Dict, Tuple, Optional
from tinygrad.helpers import dtypes, DType
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, FusedOps, Op, Interpreted
from tinygrad.runtime.lib import RawBuffer
@@ -35,8 +35,8 @@ numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
}}
class RawNumpyBuffer(RawBuffer):
def __init__(self, buf:np.ndarray): super().__init__(buf.size, dtypes.from_np(buf.dtype), buf)
def __init__(self, size:int, dtype:DType, buf:Optional[np.ndarray]=None): super().__init__(size, dtype, buf if buf is not None else np.empty([size], dtype.np))
@classmethod
def fromCPU(cls, x): return cls(x)
def fromCPU(cls, x): return cls(x.size, dtypes.from_np(x.dtype), x)
def toCPU(self): return self._buf
CPUBuffer = Interpreted(RawNumpyBuffer, numpy_fxn_for_op)
CPUBuffer = Interpreted(RawNumpyBuffer, numpy_fxn_for_op, from_underlying=RawNumpyBuffer.fromCPU)

View File

@@ -0,0 +1,25 @@
import os, mmap
from typing import Optional
from typing import Callable, Dict
from tinygrad.helpers import prod
from tinygrad.runtime.lib import RawBufferMapped
from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps
class RawDiskBuffer(RawBufferMapped):
  # Raw buffer whose storage is a memory-mapped file, or a view into one.
  def __init__(self, size, dtype, device:Optional[str]=None, buf=None, shape=None):
    # size:   number of elements (not bytes)
    # device: path of the backing file; when given, the file is mapped here
    # buf:    already existing buffer to wrap instead of mapping a file
    # shape:  logical shape of the buffer, defaults to flat (size,)
    self.shape = (size, ) if shape is None else shape
    assert device is not None or buf is not None, "disk tensor needs a path or a buf"
    if device is not None:
      # "a+b" creates the file when it is missing and never truncates existing data
      with open(device, "a+b") as f:
        # grow the file if it is too small to hold the whole tensor
        if os.path.getsize(device) < size * dtype.itemsize: os.ftruncate(f.fileno(), size * dtype.itemsize)
        buf = memoryview(mmap.mmap(f.fileno(), size * dtype.itemsize))
    super().__init__(size, dtype, buf)

  # same storage and element count, only the logical shape changes
  def reshape(self, arg): return RawDiskBuffer(self.size, self.dtype, buf=self._buffer(), shape=arg)

  def shrink(self, arg):
    # arg holds one (start, end) pair per dimension; only the first dimension
    # may actually be cut, all others have to cover their full range
    assert arg[1:] == tuple([(0,x) for x in self.shape[1:]]), f"can only slice the first dim of disk tensor {arg}"
    start, end = arg[0]
    row = prod(self.shape[1:])  # elements in one index of the first dimension
    row_bytes = row*self.dtype.itemsize
    # size is counted in elements, so the kept rows are multiplied by the row
    # length, and the trailing dimensions are carried over so that reshaping or
    # shrinking the result again uses the right offsets
    return RawDiskBuffer((end-start)*row, self.dtype, buf=self._buffer()[start*row_bytes:end*row_bytes], shape=(end-start,)+tuple(self.shape[1:]))

  def _buffer(self): return self._buf
# the only ops defined for disk buffers: NOOP returns its input unchanged, RESHAPE and SHRINK call the RawDiskBuffer methods
disk_fxn_for_op: Dict[Op, Callable] = { UnaryOps.NOOP: lambda x: x, MovementOps.RESHAPE: RawDiskBuffer.reshape, MovementOps.SHRINK: RawDiskBuffer.shrink }
# to_underlying and from_underlying are both the identity here, so the ops above receive and return RawDiskBuffer objects directly
DiskBuffer = Interpreted(RawDiskBuffer, disk_fxn_for_op, to_underlying=lambda x:x, from_underlying=lambda x:x)

View File

@@ -31,7 +31,7 @@ CL = _CL()
# TODO: merge CLImage in here
class CLBuffer(RawBufferCopyInOut):
def __init__(self, size, dtype, device=0):
def __init__(self, size, dtype, device='0'):
if isinstance(dtype, ImageDType):
fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize])
buf = cl.Image(CL.cl_ctx, cl.mem_flags.READ_WRITE, fmt, shape=(dtype.shape[1], dtype.shape[0]))
@@ -39,7 +39,7 @@ class CLBuffer(RawBufferCopyInOut):
# NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize
else:
buf = cl.Buffer(CL.cl_ctx, cl.mem_flags.READ_WRITE, size * dtype.itemsize)
setattr(buf, 'device', device) # device is tracked on the underlying buffer
setattr(buf, 'device', int(device)) # device is tracked on the underlying buffer
super().__init__(size, dtype, buf)
def _copyin(self, x:np.ndarray):
assert not self.dtype.name.startswith("image"), f"can't copyin images {self.dtype}"

View File

@@ -1,12 +1,13 @@
import torch
from typing import Dict, Callable
from typing import Dict, Callable, Optional
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, FusedOps, Op, Interpreted
from tinygrad.helpers import getenv, dtypes, prod
from tinygrad.helpers import getenv, dtypes, prod, DType
from tinygrad.runtime.ops_cpu import base_fxn_for_op, einsum_mulacc
from tinygrad.runtime.lib import RawBuffer
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
type_map = {torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int32: dtypes.int32, torch.int64: dtypes.int64}
inverse_type_map = {v:k for k,v in type_map.items()}
torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), UnaryOps.CAST: lambda x,y: x.type(next(k for k,v in type_map.items() if v==y)), UnaryOps.SIN: torch.sin,
@@ -18,8 +19,10 @@ torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
}}
class RawTorchBuffer(RawBuffer):
def __init__(self, buf:torch.Tensor): super().__init__(prod(buf.shape), type_map[buf.dtype], buf)
def __init__(self, size:int, dtype:DType, buf:Optional[torch.Tensor]=None): super().__init__(size, dtype, buf if buf is not None else torch.empty([size], dtype=inverse_type_map[dtype]))
@classmethod
def fromCPU(cls, x): return cls(torch.from_numpy(x).requires_grad_(False).to(device))
def fromCPU(cls, x):
buf = torch.from_numpy(x).requires_grad_(False).to(device)
return cls(prod(x.shape), type_map[buf.dtype], buf)
def toCPU(self): return self._buf.cpu().numpy()
TorchBuffer = Interpreted(RawTorchBuffer, torch_fxn_for_op)
TorchBuffer = Interpreted(RawTorchBuffer, torch_fxn_for_op, from_underlying=lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x))

View File

@@ -34,7 +34,7 @@ class Tensor:
default_type: ClassVar[DType] = dtypes.float32
def __init__(self, data:Union[list, LazyBuffer, LazyNumpyArray, np.ndarray], device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
device = device.upper().replace(":0", "") # canonicalize device
device = (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") # canonicalize device
if isinstance(data, list):
data = np.array(data, dtype=(dtype if dtype is not None else Tensor.default_type).np)
elif isinstance(data, LazyBuffer) and data.device != device:
@@ -90,8 +90,13 @@ class Tensor:
return self
def assign(self, x) -> Tensor:
if not isinstance(x, Tensor): x = Tensor(x)
assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
# TODO: this is a hack for writing to DISK
if self.device.startswith("DISK"):
if not isinstance(x, Tensor): x = Tensor(x, device="CPU", dtype=self.dtype)
self.lazydata.realize().realized._copyin(x.numpy()) # type: ignore
return self
if not isinstance(x, Tensor): x = Tensor(x, device=self.device, dtype=self.dtype)
assert self.shape == x.shape and self.device == x.device, f"assign shape mismatch {self.shape} != {x.shape} or device mismatch {self.device} != {x.device}"
assert not x.requires_grad # self requires_grad is okay?
if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
if self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized
@@ -126,7 +131,9 @@ class Tensor:
def zeros_like(tensor, **kwargs): return Tensor.zeros(*tensor.shape, **kwargs)
@staticmethod
def empty(*shape, **kwargs): return Tensor.zeros(*shape, **kwargs)
def empty(*shape, device=Device.DEFAULT, dtype:Optional[DType]=None, **kwargs):
# NOTE: we do the reshape to fix interpreted buffers
return Tensor(LazyBuffer.empty([prod(shape)], Tensor.default_type if dtype is None else dtype, device), dtype=dtype, device=device, **kwargs).reshape(*shape)
@staticmethod
def eye(dim, **kwargs): return Tensor([1], **kwargs).slice(((0,dim+1),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).slice(((0,dim*dim),)).reshape(dim, dim)
@@ -457,7 +464,7 @@ class Tensor:
def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Sub, x, reverse) if isinstance(x, Tensor) or x != 0.0 or reverse else self
def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Mul, x, reverse) if isinstance(x, Tensor) or x != 1.0 else self
def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Pow, x, reverse) if isinstance(x, Tensor) or x != 1.0 or reverse else self
def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if isinstance(x, Tensor) or reverse else self.mul(1/x)
def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if isinstance(x, Tensor) or reverse or x == 0.0 else self.mul(1/x)
def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
def maximum(self, x:Union[Tensor, float]) -> Tensor: return self._broadcasted(mlops.Maximum, x)