Refactor LoadOps (#910)

* test

* work

* upd test

* loadops

* cleanups

* real ones

* remove LazyNumpyArray

* fix assign test

* remove range

* np.require

* llama uses arange kernels

* no caching consts

* fix enet

* torch load support

* tests cleanup

* fix shufflenet

* fix image

* fix torch_load test
George Hotz
2023-06-03 09:40:43 -07:00
committed by GitHub
parent d58586bb17
commit 791530045d
20 changed files with 254 additions and 117 deletions
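In short: tensor creation moves from eager LazyNumpyArray wrappers to graph-level LoadOps, with EMPTY, RAND, and CONST joining the enum, so empty, random, and constant tensors are now built on-device at realize time. A minimal sketch of the surface-level effect (assuming tinygrad at this commit):

from tinygrad.ops import LoadOps
print([op.name for op in LoadOps])
# ['EMPTY', 'RAND', 'CONST', 'FROMCPU', 'CONTIGUOUS', 'TOCPU', 'CUSTOM']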


@@ -134,7 +134,7 @@ assert len(lazyop.src) == 2
# again, a LazyOp AST is like a GPU kernel. you have to copy the data on the device first
print(lazyop.src[0].op)
assert lazyop.src[0].op.op == LoadOps.FROMCPU
-assert lazyop.src[0].op.arg.fxn == [2], "the arg of the FROMCPU LazyOP is the [2.]"
+assert lazyop.src[0].op.arg == [2], "the arg of the FROMCPU LazyOP is the [2.]"
assert result.lazydata.realized is None, "the LazyBuffer is not realized yet"
# now we realize the LazyBuffer
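The doc change above reflects that the arg of a FROMCPU LazyOp is now the raw numpy data rather than a LazyNumpyArray with a .fxn field; a minimal sketch (assuming tinygrad at this commit):

import numpy as np
from tinygrad.tensor import Tensor
t = Tensor(np.array([2.], dtype=np.float32))
print(t.lazydata.op.op)   # LoadOps.FROMCPU
print(t.lazydata.op.arg)  # the ndarray itself, no wrapper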

extra/disk/.gitignore (new vendored file, 1 line)

@@ -0,0 +1 @@
a.out

extra/disk/test.cc (new file, 71 lines)

@@ -0,0 +1,71 @@
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <unistd.h>
#include <sys/mman.h>
#include <thread>
#include <chrono>

#define SZ (unsigned long long)(512*1024*1024)
#define CNT 10LL

void test_read() {
  int f = open("/dev/nvme0n1", O_RDONLY|O_DIRECT);
  printf("open %d\n", f);
  /*void *buf = malloc(CNT*SZ);
  printf("malloc %p\n", buf);
  mlock(buf, CNT*SZ);*/
  auto t0 = std::chrono::high_resolution_clock::now();
  void *buf = mmap(NULL, SZ*CNT, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS, -1, 0);
  auto t1 = std::chrono::high_resolution_clock::now();
  mlock(buf, CNT*SZ);
  for (int i = 0; i < CNT; i++) {
    read(f, (unsigned char*)buf+SZ*i, SZ);
  }
  auto t2 = std::chrono::high_resolution_clock::now();
  //free(buf);
  printf("malloc %p\n", buf);
  float ns = (float)std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count();
  float pns = (float)std::chrono::duration_cast<std::chrono::nanoseconds>(t1-t0).count();
  printf("read %.2f GB in %.2f s (%.2f s to prepare), %.2f GB/s\n", SZ/1e9*CNT, ns*1e-9, pns*1e-9, (SZ*CNT)/ns);
  close(f);
  munmap(buf, SZ*CNT);
}

void test_mmap() {
  int f = open("/dev/nvme0n1", O_RDONLY|O_DIRECT);
  printf("open %d\n", f);
  void *dat = mmap(NULL, SZ*CNT, PROT_READ, MAP_PRIVATE, f, 0);
  auto t1 = std::chrono::high_resolution_clock::now();
  mlock(dat, SZ*CNT);
  auto t2 = std::chrono::high_resolution_clock::now();
  printf("mmap %p\n", dat);
  float ns = (float)std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count();
  printf("read %.2f GB in %.2f s, %.2f GB/s\n", SZ/1e9*CNT, ns*1e-9, (SZ*CNT)/ns);
  close(f);
  munlock(dat, SZ*CNT);
  munmap(dat, SZ*CNT);
}

int main() {
  system("sync; echo 1 > /proc/sys/vm/drop_caches");
  test_mmap();
  //system("sync; echo 1 > /proc/sys/vm/drop_caches");
  //test_read();
}


@@ -71,7 +71,7 @@ def _padding(X, pads=None, auto_pad="NOTSET", axes=None, constant_value=0.):
  if pads is None: return X
  np_pads = _format_padding(pads, ndims=len(X.shape), axes=axes)
  zero_padded = X.pad(tuple(np_pads))
-  constant_padder = Tensor(np.pad(np.zeros(X.shape), np_pads, constant_values=constant_value), dtype=X.dtype)
+  constant_padder = Tensor(np.pad(np.zeros(X.shape, dtype=np.float32), np_pads, constant_values=constant_value), dtype=X.dtype)
  return zero_padded + constant_padder

def Pad(x: Tensor, pads: Union[Tensor, Tuple[int, ...]], constant_value: Tensor=None, axes: Tensor=None, mode="constant", value: float=0.):
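The fix above pins numpy's default dtype: np.zeros allocates float64 unless told otherwise, which no longer round-trips now that FROMCPU takes the array's dtype as-is. For example:

import numpy as np
print(np.zeros((2, 2)).dtype)                    # float64, numpy's default
print(np.zeros((2, 2), dtype=np.float32).dtype)  # float32, what the Tensor path expects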


@@ -3,10 +3,10 @@ import numpy as np
from tqdm import tqdm
import tempfile, platform
from collections import defaultdict
-from tinygrad.helpers import prod, getenv, DEBUG
+from tinygrad.helpers import prod, getenv, DEBUG, dtypes
from tinygrad.ops import GlobalCounters
from tinygrad.tensor import Tensor
-from tinygrad.lazy import LazyNumpyArray, Device
+from tinygrad.lazy import Device
from tinygrad.shape.shapetracker import strides_for_shape

OSX = platform.system() == "Darwin"
@@ -20,6 +20,15 @@ def fetch(url):
  with open(fp, "rb") as f:
    return f.read()

+def fetch_as_file(url):
+  if url.startswith("/"):
+    with open(url, "rb") as f:
+      return f.read()
+  import os, hashlib, tempfile
+  fp = os.path.join(tempfile.gettempdir(), hashlib.md5(url.encode('utf-8')).hexdigest())
+  download_file(url, fp, skip_if_exists=not getenv("NOCACHE"))
+  return fp
+
def download_file(url, fp, skip_if_exists=True):
  import requests, os, pathlib
  if skip_if_exists and os.path.isfile(fp) and os.stat(fp).st_size > 0:
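fetch_as_file caches each download under a stable per-URL path, the md5 of the URL inside the temp dir, so torch_load can mmap the same file on later runs. A sketch of the path scheme (the URL is hypothetical):

import hashlib, os, tempfile
url = "https://example.com/weights.pth"
print(os.path.join(tempfile.gettempdir(), hashlib.md5(url.encode('utf-8')).hexdigest()))
# a stable path like /tmp/<md5-of-url>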
@@ -46,7 +55,7 @@ def my_unpickle(fb0):
      if DEBUG: print(f"unsupported type {storage_type} on {obj_key} with shape {size}")
      ret = None
    else:
-      ret = Tensor(LazyNumpyArray(lambda lst: np.zeros(lst.shape, dtype=lst.dtype), tuple(size), storage_type))
+      ret = Tensor.empty(*size, dtype=dtypes.from_np(storage_type))
    key_prelookup[obj_key].append((storage_type, obj_size, ret, size, stride, storage_offset))
    return ret


@@ -2,7 +2,7 @@ import math
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn import BatchNorm2d
-from extra.utils import fetch, fake_torch_load, get_child
+from extra.utils import get_child

class MBConvBlock:
  def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se, track_running_stats=True):
@@ -143,8 +143,9 @@ class EfficientNet:
      7: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth"
    }
-    b0 = fake_torch_load(fetch(model_urls[self.number]))
+    from extra.utils import fetch_as_file
+    from tinygrad.state import torch_load
+    b0 = torch_load(fetch_as_file(model_urls[self.number]))
    for k,v in b0.items():
      if k.endswith("num_batches_tracked"): continue
      for cat in ['_conv_head', '_conv_stem', '_depthwise_conv', '_expand_conv', '_fc', '_project_conv', '_se_reduce', '_se_expand']:
@@ -155,11 +156,11 @@ class EfficientNet:
      #print(k, v.shape)
      mv = get_child(self, k)
      vnp = v #.astype(np.float32)
-      vnp = vnp if k != '_fc' else vnp.transpose()
+      vnp = vnp if k != '_fc' else vnp.cpu().T
      #vnp = vnp if vnp.shape != () else np.array([vnp])

      if mv.shape == vnp.shape:
-        mv.assign(vnp)
+        mv.assign(vnp.to(mv.device))
      else:
        print("MISMATCH SHAPE IN %s, %r %r" % (k, mv.shape, vnp.shape))


@@ -27,7 +27,7 @@ class CLCache():
  def __exit__(self, type, value, traceback):
    print(f"cache: exiting with size {len(GlobalCounters.cache)}", f"allowed {self.allowed}" if self.allowed is not None else "")
    if self.allowed is not None:
-      assert len(GlobalCounters.cache) <= self.allowed and (not self.strict or len(GlobalCounters.cache) == self.allowed), "used too many kernels!"
+      assert len(GlobalCounters.cache) <= self.allowed and (not self.strict or len(GlobalCounters.cache) == self.allowed), f"used too many kernels! {len(GlobalCounters.cache)} > {self.allowed}"
    GlobalCounters.cache = None
from models.convnext import ConvNeXt
@@ -85,7 +85,7 @@ class TestInferenceMinKernels(unittest.TestCase):
    args_tiny = {"dim": 512, "multiple_of": 256, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
    model = Transformer(**args_tiny)
    for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
-    with CLCache(82):
+    with CLCache(85):
      model(Tensor([[1,2,3,4]]), 0).realize()
  @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")


@@ -2,7 +2,8 @@
import io
import unittest
from tinygrad.helpers import getenv
-from extra.utils import fetch, fake_torch_load_zipped
+from extra.utils import fetch
+from tinygrad.state import torch_load
from PIL import Image

@unittest.skipIf(getenv("CI", "") != "", "no internet tests in CI")
@@ -30,10 +31,10 @@ class TestUtils(unittest.TestCase):
        super(LayerWithOffset, self).__init__()
        d = torch.randn(16)
        self.param1 = torch.nn.Parameter(
-          d.as_strided([2, 2], [2, 3], storage_offset=5)
+          d.as_strided([2, 2], [1, 2], storage_offset=5)
        )
        self.param2 = torch.nn.Parameter(
-          d.as_strided([2, 2], [2, 3], storage_offset=4)
+          d.as_strided([2, 2], [1, 2], storage_offset=4)
        )

    for isfloat16 in [True, False]:
@@ -47,7 +48,7 @@ class TestUtils(unittest.TestCase):
      with tempfile.TemporaryDirectory() as tmpdirname:
        path = tmpdirname + '/testloadmodel.pth'
        torch.save(model.state_dict(), path)
-        model2 = fake_torch_load_zipped(path)
+        model2 = torch_load(path)
        for name, a in model.state_dict().items():
          b = model2[name]
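The stride change above matters because the new torch_load only reconstructs layouts whose strides are a permutation of some contiguous layout; strides [1, 2] describe a transposed contiguous 2x2 block, while the old [2, 3] would fail its "nonpermutable strides" assert. A minimal sketch (assuming torch is installed):

import torch
d = torch.randn(16)
p = d.as_strided([2, 2], [1, 2], storage_offset=5)
# (1, 2) strides = the transpose of a contiguous (2, 2) block,
# exactly the case the loader's permute path can rebuild from disk
print(p.stride(), p.is_contiguous())  # (1, 2) False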


@@ -10,8 +10,8 @@ N = 200 # has to be bigger than the cache to fail
class TestAssign(unittest.TestCase):
  def test_simple_assignment(self):
-    a = Tensor.arange(N*N).reshape(N,N)
-    b = Tensor.arange(N*N).reshape(N,N)
+    a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
+    b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
    a.realize()
    b.realize()
    ba1 = a.lazydata.realized
@@ -23,8 +23,8 @@ class TestAssign(unittest.TestCase):
    np.testing.assert_allclose(a.numpy(), (np.arange(N*N)*2).reshape((N,N)))

  def test_permuted_assignment(self):
-    a = Tensor.arange(N*N).reshape(N,N)
-    b = Tensor.arange(N*N).reshape(N,N)
+    a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
+    b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
    a.realize()
    b.realize()
    ba1 = a.lazydata.realized
@@ -37,8 +37,8 @@ class TestAssign(unittest.TestCase):
    np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))

  def test_post_permuted_assignment(self):
-    a = Tensor.arange(N*N).reshape(N,N)
-    b = Tensor.arange(N*N).reshape(N,N)
+    a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
+    b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
    a.realize()
    b.realize()
    #GlobalCounters.cache = []


@@ -40,6 +40,8 @@ def kstest(l1, l2):
  return prob

def normal_test(func, shape=(20, 23), alpha=0.05):
+  Tensor.manual_seed(1337)
+  np.random.seed(1337)
  x = func(*shape).cpu().numpy().flatten()
  y = np.random.randn(*shape).flatten()
  return kstest(x, y) >= alpha


@@ -6,6 +6,28 @@ from tinygrad.state import safe_load, safe_save, get_state_dict
from tinygrad.helpers import dtypes
from tinygrad.runtime.ops_disk import RawDiskBuffer
from extra.helpers import Timing
+from extra.utils import fetch_as_file
+from tinygrad.state import torch_load, get_state_dict
+
+def compare_weights_both(url):
+  import torch
+  fn = fetch_as_file(url)
+  tg_weights = get_state_dict(torch_load(fn))
+  torch_weights = get_state_dict(torch.load(fn), tensor_type=torch.Tensor)
+  assert list(tg_weights.keys()) == list(torch_weights.keys())
+  for k in tg_weights:
+    np.testing.assert_equal(tg_weights[k].numpy(), torch_weights[k].numpy(), err_msg=f"mismatch at {k}, {tg_weights[k].shape}")
+  print(f"compared {len(tg_weights)} weights")
+
+class TestTorchLoad(unittest.TestCase):
+  # pytorch pkl format
+  def test_load_enet(self): compare_weights_both("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
+  # pytorch zip format
+  def test_load_enet_alt(self): compare_weights_both("https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth")
+  # pytorch zip format
+  def test_load_convnext(self): compare_weights_both('https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth')
+  # TODO: support pytorch tar format with minimal lines
+  #def test_load_resnet(self): compare_weights_both('https://download.pytorch.org/models/resnet50-19c8e357.pth')
+
test_fn = pathlib.Path(__file__).parent.parent.parent / "weights/LLaMA/7B/consolidated.00.pth"
#test_size = test_fn.stat().st_size
@@ -90,7 +112,7 @@ class TestDiskTensor(unittest.TestCase):
  def test_slice(self):
    pathlib.Path("/tmp/dt3").unlink(missing_ok=True)
-    Tensor.arange(10, device="disk:/tmp/dt3").realize()
+    Tensor.arange(10, device="CPU").to("disk:/tmp/dt3").realize()

    slice_me = Tensor.empty(10, device="disk:/tmp/dt3")
    print(slice_me)


@@ -448,7 +448,7 @@ class Linearizer:
    for buf_index,buf in enumerate(self.bufs):
      unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes() if self.sts[buf_index].shape[i]%4 == 0]
      if (not early_only or buf in self.earlybufs) and isinstance(self.bufs[buf_index].dtype, ImageDType):
-        assert len(unit_stride_axes_mul_4) >= 1, "needs a unit stride axis"
+        assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
        if all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
          self.shift_to(unit_stride_axes_mul_4[0], 4)
          self.upcast()


@@ -57,14 +57,6 @@ class ImageDType(DType):
    super().__init__()
  def __repr__(self): return f"dtypes.{self.name}({self.shape})"

-class LazyNumpyArray:
-  def __init__(self, fxn, shape, dtype): self.fxn, self.shape, self.dtype = fxn, shape, dtype
-  def __call__(self) -> np.ndarray: return np.require(self.fxn(self) if callable(self.fxn) else self.fxn, dtype=self.dtype, requirements='C').reshape(self.shape)
-  def reshape(self, new_shape): return LazyNumpyArray(self.fxn, new_shape, self.dtype)
-  def copy(self): return self if callable(self.fxn) else LazyNumpyArray(self.fxn, self.shape, self.dtype)
-  def astype(self, typ): return LazyNumpyArray(self.fxn, self.shape, typ)
-
@dataclass
class dtypes:
  @staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool


@@ -1,8 +1,9 @@
from __future__ import annotations
from typing import Optional, Tuple, Union, List, Dict, Any, cast
import sys, weakref, importlib, inspect, functools, pathlib
+import numpy as np
from weakref import WeakValueDictionary
-from tinygrad.helpers import prod, getenv, DType, dtypes, LazyNumpyArray, flatten, ImageDType, DEBUG
+from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, ImageDType, DEBUG
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, get_lazyops, get_buffers, map_buffers
from tinygrad.runtime.lib import RawConst, RawBuffer
@@ -71,7 +72,7 @@ def create_lazybuffer(device:str, shape:Union[ShapeTracker, Tuple[int, ...]], op
  st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))

  # fromcpu aren't cached
-  if optype == LoadOps and op.op in [LoadOps.FROMCPU, LoadOps.EMPTY]: return LazyBuffer(device, st, optype, op, dtype)
+  if optype == LoadOps and op.op in [LoadOps.FROMCPU, LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST]: return LazyBuffer(device, st, optype, op, dtype)
  #print("create_lazybuffer", device, shape, optype, op, dtype)
@@ -108,13 +109,8 @@ class LazyBuffer:
    if self.realized is None:
      # get real ops first
      if self.op.op == LoadOps.FROMCPU:
-        if prod(self.op.arg.shape) == 1 and hasattr(Device[self.device].codegen, 'supports_constant_folding'):
-          self.realized = RawConst(1, dtypes.from_np(self.op.arg.dtype), self.op.arg().flatten()[0])
-        else:
-          if DEBUG >= 4: print(f"copying {self.op.arg.shape}:{dtypes.from_np(self.op.arg.dtype)} -> {self.device}")
-          self.realized = Device[self.device].buffer.fromCPU(self.op.arg(), **self._device_extra_args())
-      elif self.op.op == LoadOps.EMPTY:
-        self.realized = Device[self.device].buffer(prod(self.shape), self.dtype, **self._device_extra_args())
+        if DEBUG >= 4: print(f"copying {self.op.arg.shape}:{dtypes.from_np(self.op.arg.dtype)} -> {self.device}")
+        self.realized = Device[self.device].buffer.fromCPU(self.op.arg, **self._device_extra_args())
      elif self.op.op == LoadOps.CONTIGUOUS:
        realized = self.op.src[0].realize().realized
        if self.op.src[0].st.contiguous and not isinstance(realized, RawConst) and realized.size == prod(self.shape):
@@ -126,6 +122,18 @@ class LazyBuffer:
      elif self.op.op == LoadOps.CUSTOM:
        # this needs to immediately realize
        self.realized = self.op.arg(self, *[x.realize() for x in self.op.src])
+      elif self.optype == LoadOps:
+        if DEBUG >= 4: print(f"{self.op.op} {self.shape} {self.dtype} {self.op.arg}")
+        if self.op.op == LoadOps.EMPTY:
+          self.realized = Device[self.device].buffer(prod(self.shape), self.dtype, **self._device_extra_args())
+        elif self.op.op == LoadOps.RAND:
+          rng = np.random.default_rng(self.op.arg)
+          self.realized = Device[self.device].buffer.fromCPU(rng.random(size=self.shape, dtype=self.dtype.np), **self._device_extra_args())
+        elif self.op.op == LoadOps.CONST:
+          if hasattr(Device[self.device].codegen, 'supports_constant_folding'):
+            self.realized = RawConst(1, self.dtype, float(self.op.arg))
+          else:
+            self.realized = Device[self.device].buffer.fromCPU(np.array(self.op.arg, dtype=self.dtype.np), **self._device_extra_args())
      # these can be late folded and change the op to go further back in the graph
      elif self.optype == ReduceOps: self.op = _ast_reduceops(self)
      elif self.optype == BinaryOps: self.op = _ast_binaryops(self) # ISSUE: this can include a reshape
@@ -158,18 +166,14 @@ class LazyBuffer:
    del self.op
    return self

-  # NOTE: we have to make a copy of the numpy array here in case the user changes it. expose this? LazyNumpyArray doesn't have this problem
  @staticmethod
-  def fromCPU(x:LazyNumpyArray, device) -> LazyBuffer:
-    return create_lazybuffer(device, x.shape, LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), x), dtypes.from_np(x.dtype))
-  @staticmethod
-  def empty(shape, dtype, device) -> LazyBuffer:
-    return create_lazybuffer(device, shape, LoadOps, LazyOp(LoadOps.EMPTY, tuple()), dtype)
+  def loadop(op, shape, dtype, device, arg=None) -> LazyBuffer:
+    return create_lazybuffer(device, shape, LoadOps, LazyOp(op, tuple(), arg), dtype)

  # create a constant with the shape and dtype of self
  def const_like(self, val) -> LazyBuffer:
-    return create_lazybuffer(self.device, (1,), LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), LazyNumpyArray([val], (1,), self.dtype.np)), self.dtype) \
+    # NOTE: dtypes.from_np(self.dtype.np) to deal with image types
+    return self.loadop(LoadOps.CONST, tuple(), dtypes.from_np(self.dtype.np), self.device, arg=val) \
      .movement_op(MovementOps.RESHAPE, (1,)*len(self.shape)).movement_op(MovementOps.EXPAND, self.shape)

  # NOTE: we also have to copy the numpy array on the way out...otherwise the underlying Tensor could be freed and use after free. improve this?
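A sketch of the new lowering end to end (assuming tinygrad at this commit): a Python scalar becomes a LoadOps.CONST whose arg is the value itself, and realize either folds it to a RawConst on codegen backends or materializes a one-element buffer:

from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps
t = Tensor(2.0)
assert t.lazydata.op.op == LoadOps.CONST and t.lazydata.op.arg == 2.0
t.realize()  # RawConst if the backend advertises supports_constant_folding, else a 1-element buffer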


@@ -67,13 +67,15 @@ class LAMB(Optimizer):
      t.assign(t.detach() - self.lr * r * up)
    self.realize([self.t] + self.m + self.v)

-def get_state_dict(obj, prefix:str='') -> Dict[str, Tensor]:
-  if isinstance(obj, Tensor): return {prefix.strip('.'):obj}
-  if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix)
+from collections import OrderedDict
+def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
+  if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
+  if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
+  if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix, tensor_type)
  state_dict = {}
  if isinstance(obj, (list, tuple)):
-    for i,x in enumerate(obj): state_dict.update(get_state_dict(x, f"{prefix}{str(i)}."))
+    for i,x in enumerate(obj): state_dict.update(get_state_dict(x, f"{prefix}{str(i)}.", tensor_type))
  elif isinstance(obj, dict):
-    for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}."))
+    for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
  return state_dict
def get_parameters(obj) -> List[Tensor]: return list(get_state_dict(obj).values())
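The new tensor_type parameter lets the same walker flatten a torch state dict, which is how the torch_load tests compare weights; a minimal sketch (assuming torch is installed):

import torch
from tinygrad.nn.optim import get_state_dict
sd = get_state_dict(torch.nn.Linear(4, 4).state_dict(), tensor_type=torch.Tensor)
print(sorted(sd))  # ['bias', 'weight']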


@@ -12,7 +12,7 @@ class UnaryOps(Enum): NOOP = auto(); EXP = auto(); LOG = auto(); CAST = auto();
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto() # noqa: E702
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class FusedOps(Enum): MULACC = auto() # noqa: E702
-class LoadOps(Enum): EMPTY = auto(); FROMCPU = auto(); CONTIGUOUS = auto(); TOCPU = auto(); CUSTOM = auto() # noqa: E702
+class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROMCPU = auto(); CONTIGUOUS = auto(); TOCPU = auto(); CUSTOM = auto() # noqa: E702

Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, FusedOps]
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[FusedOps]]


@@ -14,7 +14,6 @@ class RawDiskBuffer(RawBufferMapped):
      with open(device, "a+b") as f:
        if os.path.getsize(device) < size * dtype.itemsize: os.ftruncate(f.fileno(), size * dtype.itemsize)
        buf = mmap.mmap(f.fileno(), size * dtype.itemsize)
-        buf.madvise(mmap.MADV_SEQUENTIAL)
    # NOTE: we don't call super since disk tensors don't use RAM
    self.size, self.dtype, self._buf = size, dtype, buf
  def cast(self, new_dtype:DType): return RawDiskBuffer(self.size, new_dtype, buf=self._buf, shape=self.shape, offset=self.offset)


@@ -43,7 +43,7 @@ class CLBuffer(RawBufferCopyInOut):
    super().__init__(size, dtype, buf)
  def _copyin(self, x:np.ndarray):
    assert not self.dtype.name.startswith("image"), f"can't copyin images {self.dtype}"
-    cl.enqueue_copy(CL.cl_queue[self._buf.device], self._buf, x, is_blocking=False)
+    cl.enqueue_copy(CL.cl_queue[self._buf.device], self._buf, np.require(x, requirements='C'), is_blocking=False)
  def _copyout(self, x:np.ndarray):
    assert not self.dtype.name.startswith("image"), f"can't copyout images {self.dtype}"
    cl.enqueue_copy(CL.cl_queue[self._buf.device], x, self._buf, is_blocking=True)
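np.require matters here because LazyNumpyArray used to guarantee a C-contiguous array (see its removed __call__ above) and raw ndarrays now reach the copy directly; a non-contiguous view would otherwise be enqueued as-is. For example:

import numpy as np
x = np.arange(6, dtype=np.float32).reshape(2, 3).T  # transposed view, not C-contiguous
print(x.flags['C_CONTIGUOUS'])                                # False
print(np.require(x, requirements='C').flags['C_CONTIGUOUS'])  # True: copied into C order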


@@ -1,32 +1,12 @@
-import os, json, pathlib
+import os, json, pathlib, zipfile, pickle
from typing import Dict, Union
from tinygrad.tensor import Tensor
-from tinygrad.helpers import dtypes, prod
+from tinygrad.helpers import dtypes, prod, argsort
from tinygrad.shape.shapetracker import strides_for_shape

safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64}
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}

-def torch_load(fn:str):
-  import zipfile, pickle
-  myzip = zipfile.ZipFile(fn, 'r')
-  base_name = myzip.namelist()[0].split('/', 1)[0]
-  t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
-  def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
-    #print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
-    with myzip.open(f'{base_name}/data/{storage[2]}') as myfile:
-      offset = myfile._orig_compress_start # type: ignore
-      return t[offset:offset+prod(size)].cast(storage[1]).reshape(size)
-  intercept = {"HalfStorage": dtypes.float16, "_rebuild_tensor_v2": _rebuild_tensor_v2}
-  class TorchPickle(pickle.Unpickler):
-    def find_class(self, module, name):
-      if module.startswith("torch"): return intercept[name]
-      return super().find_class(module, name)
-    def persistent_load(self, pid): return pid
-  with myzip.open(f'{base_name}/data.pkl') as myfile: return TorchPickle(myfile).load()
-
def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
  t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
  json_len = t[0:1].cast(dtypes.int64).numpy()[0]
@@ -48,3 +28,51 @@ def safe_save(tensors:Dict[str, Tensor], fn:str):
# TODO: move get_state_dict and get_parameters here
from tinygrad.nn.optim import get_state_dict, get_parameters # pylint: disable=unused-import # noqa: F401

+# torch support!
+def torch_load(fn:str):
+  t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
+
+  offsets: Dict[str, int] = {}
+  lens: Dict[str, int] = {}
+  def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
+    #print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
+    lens[storage[2]] = storage[4] * storage[1].itemsize
+    if storage[2] not in offsets: return None
+    byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
+    ret = t[byte_offset:byte_offset+prod(size)].cast(storage[1])
+    # 6 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
+    shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
+    permute_indexes = [len(shape_strides)-1-y for y in argsort([x[1] for x in shape_strides])]
+    if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
+      intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
+      assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
+      ret = ret.cpu().reshape(intermediate_shape).permute(permute_indexes)
+    return ret.reshape(size)
+
+  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2}
+  class TorchPickle(pickle.Unpickler):
+    def find_class(self, module, name): return intercept[name] if module.startswith("torch") else super().find_class(module, name)
+    def persistent_load(self, pid): return pid
+
+  if tuple(t[0:2].numpy()) == (0x50, 0x4b):
+    myzip = zipfile.ZipFile(fn, 'r')
+    base_name = myzip.namelist()[0].split('/', 1)[0]
+    for n in myzip.namelist():
+      if n.startswith(f'{base_name}/data/'):
+        with myzip.open(n) as myfile:
+          offsets[n.split("/")[-1]] = myfile._orig_compress_start # type: ignore
+    with myzip.open(f'{base_name}/data.pkl') as myfile:
+      return TorchPickle(myfile).load()
+  else:
+    with open(fn, "rb") as f:
+      pkl = TorchPickle(f)
+      _, _, _, rwd, _, ids, base_offset = pkl.load(), pkl.load(), pkl.load(), f.tell(), pkl.load(), pkl.load(), f.tell()
+      for i in ids:
+        offsets[i] = base_offset + 8
+        base_offset += 8 + lens[i]
+      f.seek(rwd)
+      return TorchPickle(f).load()
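In the legacy (non-zip) branch the file is effectively unpickled twice: the first pass records each storage's byte length via _rebuild_tensor_v2 (returning None), the offsets loop then walks the storages (the 8 appears to skip a per-storage 8-byte element-count header), and the second pass builds real disk-backed tensors. A hypothetical usage sketch (the path is illustrative):

from tinygrad.state import torch_load, get_state_dict
state = torch_load("/path/to/checkpoint.pth")  # zip or legacy pickle checkpoints
print(len(get_state_dict(state)), "tensors, lazily mapped from disk")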


@@ -3,8 +3,9 @@ from __future__ import annotations
import math, functools, itertools, operator
import numpy as np
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
-from tinygrad.helpers import prod, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, LazyNumpyArray
+from tinygrad.helpers import prod, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
from tinygrad.lazy import Device, LazyBuffer
+from tinygrad.ops import LoadOps

# An instantiation of the Function is the Context
class Function:
@@ -33,24 +34,22 @@ class Tensor:
  no_grad: ClassVar[bool] = False
  default_type: ClassVar[DType] = dtypes.float32
-  def __init__(self, data:Union[int, float, list, LazyBuffer, LazyNumpyArray, np.ndarray], device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
+  def __init__(self, data:Union[int, float, list, LazyBuffer, np.ndarray], device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
    assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
    device = Device.canonicalize(device)

-    if isinstance(data, (int, float, list)):
+    if isinstance(data, list):
      data = np.array(data, dtype=(dtype if dtype is not None else Tensor.default_type).np)
    elif isinstance(data, LazyBuffer) and data.device != device:
      # TODO: this has to realize, it shouldn't have to
      data = data.realize().toCPU()

-    # all ndarrays are lazy now
-    if isinstance(data, np.ndarray): data = LazyNumpyArray(data, data.shape, data.dtype)
-    # by here, it's either LazyNumpyArray or LazyBuffer
-    # TODO: it should all be LazyBuffer I think
-    if isinstance(data, LazyNumpyArray):
-      lazydata = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None else data, device)
-    elif isinstance(data, LazyBuffer):
+    if isinstance(data, LazyBuffer):
      assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
      lazydata = data
+    elif isinstance(data, np.ndarray):
+      lazydata = LazyBuffer.loadop(LoadOps.FROMCPU, data.shape, dtypes.from_np(data.dtype), device, data)
+    elif isinstance(data, (int, float)):
+      lazydata = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype if dtype is not None else Tensor.default_type, device, data)
    else:
      raise RuntimeError(f"can't create Tensor from {data}")
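A minimal sketch of the constructor dispatch above (assuming tinygrad at this commit): lists still eagerly become numpy arrays, ndarrays lower to FROMCPU, and Python scalars stay symbolic as CONST:

import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps
print(Tensor([1, 2, 3]).lazydata.op.op)                           # LoadOps.FROMCPU (list -> np.array first)
print(Tensor(np.zeros((2, 2), dtype=np.float32)).lazydata.op.op)  # LoadOps.FROMCPU
print(Tensor(3.0).lazydata.op.op)                                 # LoadOps.CONST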
@@ -116,6 +115,24 @@ class Tensor:
    if self.grad: ret.grad = self.grad.to(device)
    return ret

+  # ***** creation llop entrypoint *****
+  @staticmethod
+  def _loadop(op, sz, device=Device.DEFAULT, dtype:Optional[DType]=None, arg=None, **kwargs):
+    return Tensor(LazyBuffer.loadop(op, [sz], Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs)
+
+  @staticmethod
+  def empty(*shape, **kwargs): return Tensor._loadop(LoadOps.EMPTY, prod(shape), **kwargs).reshape(shape)
+
+  _seed: int = 1337
+  @staticmethod
+  def manual_seed(seed=None): Tensor._seed = seed
+
+  @staticmethod
+  def rand(*shape, **kwargs):
+    Tensor._seed += 1
+    return Tensor._loadop(LoadOps.RAND, prod(shape), arg=Tensor._seed, **kwargs).reshape(shape)
+
  # ***** creation helper functions *****

  @staticmethod
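Randomness is now a device-side LoadOps.RAND keyed by a plain counter: manual_seed resets Tensor._seed, every rand() bumps it, and realization seeds np.random.default_rng with that per-tensor value, so a given seed replays the same sequence. A minimal sketch:

import numpy as np
from tinygrad.tensor import Tensor
Tensor.manual_seed(1337)
a = Tensor.rand(3).numpy()
Tensor.manual_seed(1337)
b = Tensor.rand(3).numpy()
np.testing.assert_equal(a, b)  # identical: both realized from default_rng(1338)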
@@ -127,9 +144,12 @@ class Tensor:
  @staticmethod
  def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1, **kwargs)

+  @staticmethod
+  def arange(stop, start=0, step=1, **kwargs): return Tensor.full(((stop-start)//step,), step).cumsum() + (start - step)
+
  @staticmethod
  def full_like(tensor, fill_value, dtype:Optional[DType]=None, **kwargs):
-    return Tensor.full(tensor.shape, fill_value, dtype=tensor.dtype if dtype is None else dtype, **kwargs)
+    return Tensor.full(tensor.shape, fill_value=fill_value, dtype=tensor.dtype if dtype is None else dtype, **kwargs)
  @staticmethod
  def zeros_like(tensor, **kwargs): return Tensor.full_like(tensor, 0, **kwargs)
@@ -137,39 +157,21 @@ class Tensor:
  @staticmethod
  def ones_like(tensor, **kwargs): return Tensor.full_like(tensor, 1, **kwargs)

-  @staticmethod
-  def empty(*shape, device=Device.DEFAULT, dtype:Optional[DType]=None, **kwargs):
-    # NOTE: we do the reshape to fix interpreted buffers
-    return Tensor(LazyBuffer.empty([prod(shape)], Tensor.default_type if dtype is None else dtype, Device.canonicalize(device)), dtype=dtype, device=device, **kwargs).reshape(*shape)
  @staticmethod
  def eye(dim, **kwargs): return Tensor([1], **kwargs).slice(((0,dim+1),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).slice(((0,dim*dim),)).reshape(dim, dim)

-  # TODO: below line, remove use of numpy here and make lazy
+  # TODO: requires cumsum to remove numpy
-  @staticmethod
-  def arange(stop, start=0, step=1, **kwargs): return Tensor(np.arange(start=start, stop=stop, step=step, dtype=np.float32), **kwargs)

  def where(self:Tensor, input_:Union[Tensor, float], other:Union[Tensor, float]):
    cond = (self != 0.0)
    return cond * input_ + (1.0 - cond) * other

-  # ***** (numpy) rng helper functions *****
-  # TODO: move randomness generation out of numpy
-  _rng: ClassVar[np.random.Generator] = np.random.default_rng()
-  @staticmethod
-  def manual_seed(seed=None): Tensor._rng = np.random.default_rng(seed)
-  @staticmethod
-  def rand(*shape, **kwargs) -> Tensor: return Tensor(LazyNumpyArray(lambda lna: Tensor._rng.random(size=lna.shape, dtype=lna.dtype), shape, np.float32), **kwargs)
-  # TODO: replace with a transformation from uniform -> gaussian
-  @staticmethod
-  def randn(*shape, **kwargs) -> Tensor: return Tensor(LazyNumpyArray(lambda lna: Tensor._rng.standard_normal(size=lna.shape, dtype=lna.dtype), shape, np.float32), **kwargs)

  # ***** rng hlops *****
+  @staticmethod
+  def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
+    # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
+    src = Tensor.rand(2, *shape, **kwargs)
+    return src[0].mul(2*math.pi).cos().mul(src[1].log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
  @staticmethod
  def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor: return ((high-low) * Tensor.rand(*shape, **kwargs)) + low
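The Box-Muller transform used above turns two uniforms into a standard normal, z = cos(2*pi*u1) * sqrt(-2*ln(u2)); a quick numpy check of that identity:

import numpy as np
rng = np.random.default_rng(0)
u1, u2 = rng.random(100_000), rng.random(100_000)
z = np.cos(2*np.pi*u1) * np.sqrt(-2*np.log(u2))
print(round(z.mean(), 2), round(z.std(), 2))  # ~0.0, ~1.0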
@@ -184,7 +186,7 @@ class Tensor:
  @staticmethod
  def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
    bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
-    return Tensor.uniform(*shape, low=-bound, high=bound)
+    return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)

  # ***** toposort and backward pass *****
  def deepwalk(self):
@@ -460,6 +462,9 @@ class Tensor:
    r = (x*w).sum(-1)
    return r.reshape((*r.shape[:-2], r.shape[-1])) if len(self.shape) == 1 else r

+  # TODO: make this work for n-dimensional inputs
+  def cumsum(self): return self.reshape(1, 1, 1, self.shape[0]).conv2d(Tensor.ones(1, 1, 1, self.shape[0]), padding=(self.shape[0] - 1, 0, 0, 0)).flatten()
+
  # ***** mlops (unary) *****
  def contiguous(self): return mlops.Contiguous.apply(self)
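cumsum here is a 1D convolution with an all-ones kernel and left padding, which is what lets the new arange stay lazy instead of calling np.arange; a worked example (assuming tinygrad at this commit):

from tinygrad.tensor import Tensor
# arange(8, step=2): full((4,), 2) -> [2,2,2,2]; cumsum -> [2,4,6,8]; + (start-step) = -2 -> [0,2,4,6]
print(Tensor.arange(8, step=2).numpy())  # [0. 2. 4. 6.]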