Mirror of https://github.com/tinygrad/tinygrad.git
Refactor LoadOps (#910)
* test
* work
* upd test
* loadops
* cleanups
* real ones
* remove LazyNumpyArray
* fix assign test
* remove range
* np.require
* llama uses arange kernels
* no caching consts
* fix enet
* torch load support
* tests cleanup
* fix shufflenet
* fix image
* fix torch_load test
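At a glance, tensor creation now flows through LoadOps instead of LazyNumpyArray. A rough, usage-level sketch of the entrypoints this commit introduces (see the tensor.py and lazy.py hunks below); illustrative only, not part of the diff:

from tinygrad.tensor import Tensor

x = Tensor.empty(2, 3)   # LoadOps.EMPTY: uninitialized device buffer
r = Tensor.rand(2, 3)    # LoadOps.RAND: filled from a seeded generator at realize time
a = Tensor.arange(10)    # now built from Tensor.full(...).cumsum(), no np.arange
c = Tensor(3.0)          # python scalar -> LoadOps.CONST, foldable into kernels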
@@ -134,7 +134,7 @@ assert len(lazyop.src) == 2

 # again, a LazyOp AST is like a GPU kernel. you have to copy the data on the device first
 print(lazyop.src[0].op)
 assert lazyop.src[0].op.op == LoadOps.FROMCPU
-assert lazyop.src[0].op.arg.fxn == [2], "the arg of the FROMCPU LazyOP is the [2.]"
+assert lazyop.src[0].op.arg == [2], "the arg of the FROMCPU LazyOP is the [2.]"
 assert result.lazydata.realized is None, "the LazyBuffer is not realized yet"

 # now we realize the LazyBuffer
extra/disk/.gitignore (vendored, new file): 1 line
@@ -0,0 +1 @@
+a.out
extra/disk/test.cc (new file): 71 lines
@@ -0,0 +1,71 @@
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <unistd.h>
#include <sys/mman.h>

#include <thread>
#include <chrono>

#define SZ (unsigned long long)(512*1024*1024)
#define CNT 10LL

void test_read() {
  int f = open("/dev/nvme0n1", O_RDONLY|O_DIRECT);
  printf("open %d\n", f);

  /*void *buf = malloc(CNT*SZ);
  printf("malloc %p\n", buf);
  mlock(buf, CNT*SZ);*/

  auto t0 = std::chrono::high_resolution_clock::now();
  void *buf = mmap(NULL, SZ*CNT, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS, -1, 0);

  auto t1 = std::chrono::high_resolution_clock::now();
  mlock(buf, CNT*SZ);
  for (int i = 0; i < CNT; i++) {
    read(f, (unsigned char*)buf+SZ*i, SZ);
  }
  auto t2 = std::chrono::high_resolution_clock::now();

  //free(buf);
  printf("malloc %p\n", buf);
  float ns = (float)std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count();
  float pns = (float)std::chrono::duration_cast<std::chrono::nanoseconds>(t1-t0).count();
  printf("read %.2f GB in %.2f s (%.2f s to prepare), %.2f GB/s\n", SZ/1e9*CNT, ns*1e-9, pns*1e-9, (SZ*CNT)/ns);

  close(f);
  munmap(buf, SZ*CNT);
}

void test_mmap() {
  int f = open("/dev/nvme0n1", O_RDONLY|O_DIRECT);
  printf("open %d\n", f);

  void *dat = mmap(NULL, SZ*CNT, PROT_READ, MAP_PRIVATE, f, 0);

  auto t1 = std::chrono::high_resolution_clock::now();
  mlock(dat, SZ*CNT);
  auto t2 = std::chrono::high_resolution_clock::now();

  printf("mmap %p\n", dat);

  float ns = (float)std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count();
  printf("read %.2f GB in %.2f s, %.2f GB/s\n", SZ/1e9*CNT, ns*1e-9, (SZ*CNT)/ns);

  close(f);
  munlock(dat, SZ*CNT);
  munmap(dat, SZ*CNT);
}

int main() {
  system("sync; echo 1 > /proc/sys/vm/drop_caches");
  test_mmap();

  //system("sync; echo 1 > /proc/sys/vm/drop_caches");
  //test_read();
}
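Build note (an assumption, not part of the commit): the benchmark presumably compiles with a plain C++11-capable compiler invocation such as g++ -O2 extra/disk/test.cc, producing the a.out that the new .gitignore entry covers; running it needs root, both for O_RDONLY|O_DIRECT access to /dev/nvme0n1 and for writing to /proc/sys/vm/drop_caches.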
@@ -71,7 +71,7 @@ def _padding(X, pads=None, auto_pad="NOTSET", axes=None, constant_value=0.):
  if pads is None: return X
  np_pads = _format_padding(pads, ndims=len(X.shape), axes=axes)
  zero_padded = X.pad(tuple(np_pads))
-  constant_padder = Tensor(np.pad(np.zeros(X.shape), np_pads, constant_values=constant_value), dtype=X.dtype)
+  constant_padder = Tensor(np.pad(np.zeros(X.shape, dtype=np.float32), np_pads, constant_values=constant_value), dtype=X.dtype)
  return zero_padded + constant_padder

def Pad(x: Tensor, pads: Union[Tensor, Tuple[int, ...]], constant_value: Tensor=None, axes: Tensor=None, mode="constant", value: float=0.):
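For context on the decomposition this hunk touches: the ONNX Pad is built as a zero-padded tensor plus a "constant padder" that carries constant_value only in the padded region; the fix pins np.zeros to float32 so the padder's dtype matches tinygrad's default instead of numpy's float64. A hedged numpy-only sketch of the same idea (made-up shapes, not the onnx helper itself):

import numpy as np
X = np.arange(6, dtype=np.float32).reshape(2, 3)
np_pads = ((1, 1), (0, 0))                        # pad one row above and below
zero_padded = np.pad(X, np_pads)                  # data kept, padding filled with zeros
constant_padder = np.pad(np.zeros(X.shape, dtype=np.float32), np_pads, constant_values=5.0)
out = zero_padded + constant_padder               # padded region holds 5.0, data unchanged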
@@ -3,10 +3,10 @@ import numpy as np
from tqdm import tqdm
import tempfile, platform
from collections import defaultdict
-from tinygrad.helpers import prod, getenv, DEBUG
+from tinygrad.helpers import prod, getenv, DEBUG, dtypes
from tinygrad.ops import GlobalCounters
from tinygrad.tensor import Tensor
-from tinygrad.lazy import LazyNumpyArray, Device
+from tinygrad.lazy import Device
from tinygrad.shape.shapetracker import strides_for_shape
OSX = platform.system() == "Darwin"
@@ -20,6 +20,15 @@ def fetch(url):
  with open(fp, "rb") as f:
    return f.read()

+def fetch_as_file(url):
+  if url.startswith("/"):
+    with open(url, "rb") as f:
+      return f.read()
+  import os, hashlib, tempfile
+  fp = os.path.join(tempfile.gettempdir(), hashlib.md5(url.encode('utf-8')).hexdigest())
+  download_file(url, fp, skip_if_exists=not getenv("NOCACHE"))
+  return fp
+
def download_file(url, fp, skip_if_exists=True):
  import requests, os, pathlib
  if skip_if_exists and os.path.isfile(fp) and os.stat(fp).st_size > 0:
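A hedged usage note: unlike fetch, which returns raw bytes, fetch_as_file caches the download under an md5-of-URL name in the temp dir and returns the path, which is what the disk-backed torch_load needs. A minimal sketch, assuming network access (the URL is one used by the new tests further down):

from extra.utils import fetch_as_file
from tinygrad.state import torch_load

fn = fetch_as_file("https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth")
weights = torch_load(fn)   # e.g. a dict of lazy, disk-backed tensors keyed by state_dict names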
@@ -46,7 +55,7 @@ def my_unpickle(fb0):
      if DEBUG: print(f"unsupported type {storage_type} on {obj_key} with shape {size}")
      ret = None
    else:
-      ret = Tensor(LazyNumpyArray(lambda lst: np.zeros(lst.shape, dtype=lst.dtype), tuple(size), storage_type))
+      ret = Tensor.empty(*size, dtype=dtypes.from_np(storage_type))
    key_prelookup[obj_key].append((storage_type, obj_size, ret, size, stride, storage_offset))
    return ret

@@ -2,7 +2,7 @@ import math
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn import BatchNorm2d
-from extra.utils import fetch, fake_torch_load, get_child
+from extra.utils import get_child

class MBConvBlock:
  def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se, track_running_stats=True):
@@ -143,8 +143,9 @@ class EfficientNet:
      7: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth"
    }

-    b0 = fake_torch_load(fetch(model_urls[self.number]))
+    from extra.utils import fetch_as_file
+    from tinygrad.state import torch_load
+    b0 = torch_load(fetch_as_file(model_urls[self.number]))
    for k,v in b0.items():
      if k.endswith("num_batches_tracked"): continue
      for cat in ['_conv_head', '_conv_stem', '_depthwise_conv', '_expand_conv', '_fc', '_project_conv', '_se_reduce', '_se_expand']:
@@ -155,11 +156,11 @@ class EfficientNet:
      #print(k, v.shape)
      mv = get_child(self, k)
      vnp = v #.astype(np.float32)
-      vnp = vnp if k != '_fc' else vnp.transpose()
+      vnp = vnp if k != '_fc' else vnp.cpu().T
      #vnp = vnp if vnp.shape != () else np.array([vnp])

      if mv.shape == vnp.shape:
-        mv.assign(vnp)
+        mv.assign(vnp.to(mv.device))
      else:
        print("MISMATCH SHAPE IN %s, %r %r" % (k, mv.shape, vnp.shape))

test/external/external_test_opt.py (vendored): 4 lines changed
@@ -27,7 +27,7 @@ class CLCache():
  def __exit__(self, type, value, traceback):
    print(f"cache: exiting with size {len(GlobalCounters.cache)}", f"allowed {self.allowed}" if self.allowed is not None else "")
    if self.allowed is not None:
-      assert len(GlobalCounters.cache) <= self.allowed and (not self.strict or len(GlobalCounters.cache) == self.allowed), "used too many kernels!"
+      assert len(GlobalCounters.cache) <= self.allowed and (not self.strict or len(GlobalCounters.cache) == self.allowed), f"used too many kernels! {len(GlobalCounters.cache)} > {self.allowed}"
    GlobalCounters.cache = None

from models.convnext import ConvNeXt
@@ -85,7 +85,7 @@ class TestInferenceMinKernels(unittest.TestCase):
    args_tiny = {"dim": 512, "multiple_of": 256, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
    model = Transformer(**args_tiny)
    for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
-    with CLCache(82):
+    with CLCache(85):
      model(Tensor([[1,2,3,4]]), 0).realize()

  @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
@@ -2,7 +2,8 @@
import io
import unittest
from tinygrad.helpers import getenv
-from extra.utils import fetch, fake_torch_load_zipped
+from extra.utils import fetch
+from tinygrad.state import torch_load
from PIL import Image

@unittest.skipIf(getenv("CI", "") != "", "no internet tests in CI")
@@ -30,10 +31,10 @@ class TestUtils(unittest.TestCase):
        super(LayerWithOffset, self).__init__()
        d = torch.randn(16)
        self.param1 = torch.nn.Parameter(
-          d.as_strided([2, 2], [2, 3], storage_offset=5)
+          d.as_strided([2, 2], [1, 2], storage_offset=5)
        )
        self.param2 = torch.nn.Parameter(
-          d.as_strided([2, 2], [2, 3], storage_offset=4)
+          d.as_strided([2, 2], [1, 2], storage_offset=4)
        )

    for isfloat16 in [True, False]:
@@ -47,7 +48,7 @@
      with tempfile.TemporaryDirectory() as tmpdirname:
        path = tmpdirname + '/testloadmodel.pth'
        torch.save(model.state_dict(), path)
-        model2 = fake_torch_load_zipped(path)
+        model2 = torch_load(path)

        for name, a in model.state_dict().items():
          b = model2[name]
@@ -10,8 +10,8 @@ N = 200 # has to be bigger than the cache to fail

class TestAssign(unittest.TestCase):
  def test_simple_assignment(self):
-    a = Tensor.arange(N*N).reshape(N,N)
-    b = Tensor.arange(N*N).reshape(N,N)
+    a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
+    b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
    a.realize()
    b.realize()
    ba1 = a.lazydata.realized
@@ -23,8 +23,8 @@ class TestAssign(unittest.TestCase):
    np.testing.assert_allclose(a.numpy(), (np.arange(N*N)*2).reshape((N,N)))

  def test_permuted_assignment(self):
-    a = Tensor.arange(N*N).reshape(N,N)
-    b = Tensor.arange(N*N).reshape(N,N)
+    a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
+    b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
    a.realize()
    b.realize()
    ba1 = a.lazydata.realized
@@ -37,8 +37,8 @@ class TestAssign(unittest.TestCase):
    np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))

  def test_post_permuted_assignment(self):
-    a = Tensor.arange(N*N).reshape(N,N)
-    b = Tensor.arange(N*N).reshape(N,N)
+    a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
+    b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
    a.realize()
    b.realize()
    #GlobalCounters.cache = []
@@ -40,6 +40,8 @@ def kstest(l1, l2):
  return prob

def normal_test(func, shape=(20, 23), alpha=0.05):
+  Tensor.manual_seed(1337)
+  np.random.seed(1337)
  x = func(*shape).cpu().numpy().flatten()
  y = np.random.randn(*shape).flatten()
  return kstest(x, y) >= alpha
@@ -6,6 +6,28 @@ from tinygrad.state import safe_load, safe_save, get_state_dict
from tinygrad.helpers import dtypes
from tinygrad.runtime.ops_disk import RawDiskBuffer
from extra.helpers import Timing
+from extra.utils import fetch_as_file
+from tinygrad.state import torch_load, get_state_dict
+
+def compare_weights_both(url):
+  import torch
+  fn = fetch_as_file(url)
+  tg_weights = get_state_dict(torch_load(fn))
+  torch_weights = get_state_dict(torch.load(fn), tensor_type=torch.Tensor)
+  assert list(tg_weights.keys()) == list(torch_weights.keys())
+  for k in tg_weights:
+    np.testing.assert_equal(tg_weights[k].numpy(), torch_weights[k].numpy(), err_msg=f"mismatch at {k}, {tg_weights[k].shape}")
+  print(f"compared {len(tg_weights)} weights")
+
+class TestTorchLoad(unittest.TestCase):
+  # pytorch pkl format
+  def test_load_enet(self): compare_weights_both("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
+  # pytorch zip format
+  def test_load_enet_alt(self): compare_weights_both("https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth")
+  # pytorch zip format
+  def test_load_convnext(self): compare_weights_both('https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth')
+  # TODO: support pytorch tar format with minimal lines
+  #def test_load_resnet(self): compare_weights_both('https://download.pytorch.org/models/resnet50-19c8e357.pth')

test_fn = pathlib.Path(__file__).parent.parent.parent / "weights/LLaMA/7B/consolidated.00.pth"
#test_size = test_fn.stat().st_size
@@ -90,7 +112,7 @@ class TestDiskTensor(unittest.TestCase):

  def test_slice(self):
    pathlib.Path("/tmp/dt3").unlink(missing_ok=True)
-    Tensor.arange(10, device="disk:/tmp/dt3").realize()
+    Tensor.arange(10, device="CPU").to("disk:/tmp/dt3").realize()

    slice_me = Tensor.empty(10, device="disk:/tmp/dt3")
    print(slice_me)
@@ -448,7 +448,7 @@ class Linearizer:
    for buf_index,buf in enumerate(self.bufs):
      unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes() if self.sts[buf_index].shape[i]%4 == 0]
      if (not early_only or buf in self.earlybufs) and isinstance(self.bufs[buf_index].dtype, ImageDType):
-        assert len(unit_stride_axes_mul_4) >= 1, "needs a unit stride axis"
+        assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
        if all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
          self.shift_to(unit_stride_axes_mul_4[0], 4)
          self.upcast()
@@ -57,14 +57,6 @@ class ImageDType(DType):
    super().__init__()
  def __repr__(self): return f"dtypes.{self.name}({self.shape})"

-class LazyNumpyArray:
-  def __init__(self, fxn, shape, dtype): self.fxn, self.shape, self.dtype = fxn, shape, dtype
-  def __call__(self) -> np.ndarray: return np.require(self.fxn(self) if callable(self.fxn) else self.fxn, dtype=self.dtype, requirements='C').reshape(self.shape)
-  def reshape(self, new_shape): return LazyNumpyArray(self.fxn, new_shape, self.dtype)
-  def copy(self): return self if callable(self.fxn) else LazyNumpyArray(self.fxn, self.shape, self.dtype)
-  def astype(self, typ): return LazyNumpyArray(self.fxn, self.shape, typ)
-
@dataclass
class dtypes:
  @staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
@@ -1,8 +1,9 @@
from __future__ import annotations
from typing import Optional, Tuple, Union, List, Dict, Any, cast
import sys, weakref, importlib, inspect, functools, pathlib
import numpy as np
from weakref import WeakValueDictionary
-from tinygrad.helpers import prod, getenv, DType, dtypes, LazyNumpyArray, flatten, ImageDType, DEBUG
+from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, ImageDType, DEBUG
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, get_lazyops, get_buffers, map_buffers
from tinygrad.runtime.lib import RawConst, RawBuffer
@@ -71,7 +72,7 @@ def create_lazybuffer(device:str, shape:Union[ShapeTracker, Tuple[int, ...]], op
  st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))

  # fromcpu aren't cached
-  if optype == LoadOps and op.op in [LoadOps.FROMCPU, LoadOps.EMPTY]: return LazyBuffer(device, st, optype, op, dtype)
+  if optype == LoadOps and op.op in [LoadOps.FROMCPU, LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST]: return LazyBuffer(device, st, optype, op, dtype)

  #print("create_lazybuffer", device, shape, optype, op, dtype)
@@ -108,13 +109,8 @@ class LazyBuffer:
    if self.realized is None:
      # get real ops first
      if self.op.op == LoadOps.FROMCPU:
-        if prod(self.op.arg.shape) == 1 and hasattr(Device[self.device].codegen, 'supports_constant_folding'):
-          self.realized = RawConst(1, dtypes.from_np(self.op.arg.dtype), self.op.arg().flatten()[0])
-        else:
-          if DEBUG >= 4: print(f"copying {self.op.arg.shape}:{dtypes.from_np(self.op.arg.dtype)} -> {self.device}")
-          self.realized = Device[self.device].buffer.fromCPU(self.op.arg(), **self._device_extra_args())
-      elif self.op.op == LoadOps.EMPTY:
-        self.realized = Device[self.device].buffer(prod(self.shape), self.dtype, **self._device_extra_args())
+        if DEBUG >= 4: print(f"copying {self.op.arg.shape}:{dtypes.from_np(self.op.arg.dtype)} -> {self.device}")
+        self.realized = Device[self.device].buffer.fromCPU(self.op.arg, **self._device_extra_args())
      elif self.op.op == LoadOps.CONTIGUOUS:
        realized = self.op.src[0].realize().realized
        if self.op.src[0].st.contiguous and not isinstance(realized, RawConst) and realized.size == prod(self.shape):
@@ -126,6 +122,18 @@ class LazyBuffer:
      elif self.op.op == LoadOps.CUSTOM:
        # this needs to immediately realize
        self.realized = self.op.arg(self, *[x.realize() for x in self.op.src])
+      elif self.optype == LoadOps:
+        if DEBUG >= 4: print(f"{self.op.op} {self.shape} {self.dtype} {self.op.arg}")
+        if self.op.op == LoadOps.EMPTY:
+          self.realized = Device[self.device].buffer(prod(self.shape), self.dtype, **self._device_extra_args())
+        elif self.op.op == LoadOps.RAND:
+          rng = np.random.default_rng(self.op.arg)
+          self.realized = Device[self.device].buffer.fromCPU(rng.random(size=self.shape, dtype=self.dtype.np), **self._device_extra_args())
+        elif self.op.op == LoadOps.CONST:
+          if hasattr(Device[self.device].codegen, 'supports_constant_folding'):
+            self.realized = RawConst(1, self.dtype, float(self.op.arg))
+          else:
+            self.realized = Device[self.device].buffer.fromCPU(np.array(self.op.arg, dtype=self.dtype.np), **self._device_extra_args())
      # these can be late folded and change the op to go further back in the graph
      elif self.optype == ReduceOps: self.op = _ast_reduceops(self)
      elif self.optype == BinaryOps: self.op = _ast_binaryops(self) # ISSUE: this can include a reshape
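A hedged, standalone sketch of what the LoadOps.RAND path above does at realize time: the op's arg is an integer seed (Tensor.rand bumps a global Tensor._seed per call, see the tensor.py hunk further down), and realization fills a host buffer from a numpy Generator seeded with it before copying to the device. Numpy-only illustration, not the tinygrad code itself:

import numpy as np

def realize_rand(seed, shape, np_dtype=np.float32):
  # deterministic per-tensor stream: same seed gives the same buffer on any device
  rng = np.random.default_rng(seed)
  return rng.random(size=shape, dtype=np_dtype)

a = realize_rand(1338, (2, 3))
b = realize_rand(1338, (2, 3))
assert (a == b).all()   # reproducible given the seed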
@@ -158,18 +166,14 @@ class LazyBuffer:
    del self.op
    return self

-  # NOTE: we have to make a copy of the numpy array here in case the user changes it. expose this? LazyNumpyArray doesn't have this problem
-  @staticmethod
-  def fromCPU(x:LazyNumpyArray, device) -> LazyBuffer:
-    return create_lazybuffer(device, x.shape, LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), x), dtypes.from_np(x.dtype))
-
  @staticmethod
-  def empty(shape, dtype, device) -> LazyBuffer:
-    return create_lazybuffer(device, shape, LoadOps, LazyOp(LoadOps.EMPTY, tuple()), dtype)
+  def loadop(op, shape, dtype, device, arg=None) -> LazyBuffer:
+    return create_lazybuffer(device, shape, LoadOps, LazyOp(op, tuple(), arg), dtype)

  # create a constant with the shape and dtype of self
  def const_like(self, val) -> LazyBuffer:
-    return create_lazybuffer(self.device, (1,), LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), LazyNumpyArray([val], (1,), self.dtype.np)), self.dtype) \
+    # NOTE: dtypes.from_np(self.dtype.np) to deal with image types
+    return self.loadop(LoadOps.CONST, tuple(), dtypes.from_np(self.dtype.np), self.device, arg=val) \
      .movement_op(MovementOps.RESHAPE, (1,)*len(self.shape)).movement_op(MovementOps.EXPAND, self.shape)

  # NOTE: we also have to copy the numpy array on the way out...otherwise the underlying Tensor could be freed and use after free. improve this?
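const_like above never materializes a full-size buffer: it makes a single CONST element, RESHAPEs it to all-ones dims, then EXPANDs to the target shape, which is a zero-copy view at the ShapeTracker level. A hedged numpy analogy for that broadcast trick (hypothetical, just to show why one element suffices):

import numpy as np

val = np.array([3.0], dtype=np.float32)            # the single CONST element
shape = (4, 4)
view = np.lib.stride_tricks.as_strided(
  val.reshape((1,) * len(shape)),                  # RESHAPE to (1, 1)
  shape=shape, strides=(0,) * len(shape))          # EXPAND: stride-0 view, still one element
assert view.shape == shape and (view == 3.0).all()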
@@ -67,13 +67,15 @@ class LAMB(Optimizer):
      t.assign(t.detach() - self.lr * r * up)
    self.realize([self.t] + self.m + self.v)

-def get_state_dict(obj, prefix:str='') -> Dict[str, Tensor]:
-  if isinstance(obj, Tensor): return {prefix.strip('.'):obj}
-  if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix)
+from collections import OrderedDict
+def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
+  if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
+  if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
+  if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix, tensor_type)
  state_dict = {}
  if isinstance(obj, (list, tuple)):
-    for i,x in enumerate(obj): state_dict.update(get_state_dict(x, f"{prefix}{str(i)}."))
+    for i,x in enumerate(obj): state_dict.update(get_state_dict(x, f"{prefix}{str(i)}.", tensor_type))
  elif isinstance(obj, dict):
-    for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}."))
+    for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
  return state_dict
def get_parameters(obj) -> List[Tensor]: return list(get_state_dict(obj).values())
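What the tensor_type hook buys: the same recursive walk can flatten either tinygrad objects or a raw torch state dict, which is how compare_weights_both in the new torch_load tests above uses it. A small hedged sketch with a made-up model object:

from tinygrad.tensor import Tensor
from tinygrad.state import get_state_dict

class TinyBlock:
  def __init__(self):
    self.weight = Tensor.rand(4, 4)
    self.layers = [Tensor.rand(4), {"bias": Tensor.rand(4)}]

sd = get_state_dict(TinyBlock())
print(list(sd.keys()))   # e.g. ['weight', 'layers.0', 'layers.1.bias']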
@@ -12,7 +12,7 @@ class UnaryOps(Enum): NOOP = auto(); EXP = auto(); LOG = auto(); CAST = auto();
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto() # noqa: E702
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class FusedOps(Enum): MULACC = auto() # noqa: E702
-class LoadOps(Enum): EMPTY = auto(); FROMCPU = auto(); CONTIGUOUS = auto(); TOCPU = auto(); CUSTOM = auto() # noqa: E702
+class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROMCPU = auto(); CONTIGUOUS = auto(); TOCPU = auto(); CUSTOM = auto() # noqa: E702

Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, FusedOps]
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[FusedOps]]
@@ -14,7 +14,6 @@ class RawDiskBuffer(RawBufferMapped):
      with open(device, "a+b") as f:
        if os.path.getsize(device) < size * dtype.itemsize: os.ftruncate(f.fileno(), size * dtype.itemsize)
        buf = mmap.mmap(f.fileno(), size * dtype.itemsize)
-        buf.madvise(mmap.MADV_SEQUENTIAL)
    # NOTE: we don't call super since disk tensors don't use RAM
    self.size, self.dtype, self._buf = size, dtype, buf
  def cast(self, new_dtype:DType): return RawDiskBuffer(self.size, new_dtype, buf=self._buf, shape=self.shape, offset=self.offset)
@@ -43,7 +43,7 @@ class CLBuffer(RawBufferCopyInOut):
    super().__init__(size, dtype, buf)
  def _copyin(self, x:np.ndarray):
    assert not self.dtype.name.startswith("image"), f"can't copyin images {self.dtype}"
-    cl.enqueue_copy(CL.cl_queue[self._buf.device], self._buf, x, is_blocking=False)
+    cl.enqueue_copy(CL.cl_queue[self._buf.device], self._buf, np.require(x, requirements='C'), is_blocking=False)
  def _copyout(self, x:np.ndarray):
    assert not self.dtype.name.startswith("image"), f"can't copyout images {self.dtype}"
    cl.enqueue_copy(CL.cl_queue[self._buf.device], x, self._buf, is_blocking=True)
@@ -1,32 +1,12 @@
-import os, json, pathlib
+import os, json, pathlib, zipfile, pickle
from typing import Dict, Union
from tinygrad.tensor import Tensor
-from tinygrad.helpers import dtypes, prod
+from tinygrad.helpers import dtypes, prod, argsort
from tinygrad.shape.shapetracker import strides_for_shape

safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64}
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}

-def torch_load(fn:str):
-  import zipfile, pickle
-  myzip = zipfile.ZipFile(fn, 'r')
-  base_name = myzip.namelist()[0].split('/', 1)[0]
-  t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
-
-  def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
-    #print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
-    with myzip.open(f'{base_name}/data/{storage[2]}') as myfile:
-      offset = myfile._orig_compress_start # type: ignore
-      return t[offset:offset+prod(size)].cast(storage[1]).reshape(size)
-
-  intercept = {"HalfStorage": dtypes.float16, "_rebuild_tensor_v2": _rebuild_tensor_v2}
-  class TorchPickle(pickle.Unpickler):
-    def find_class(self, module, name):
-      if module.startswith("torch"): return intercept[name]
-      return super().find_class(module, name)
-    def persistent_load(self, pid): return pid
-
-  with myzip.open(f'{base_name}/data.pkl') as myfile: return TorchPickle(myfile).load()
-
def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
  t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
  json_len = t[0:1].cast(dtypes.int64).numpy()[0]
@@ -48,3 +28,51 @@ def safe_save(tensors:Dict[str, Tensor], fn:str):

# TODO: move get_state_dict and get_parameters here
from tinygrad.nn.optim import get_state_dict, get_parameters # pylint: disable=unused-import # noqa: F401
+
+# torch support!
+
+def torch_load(fn:str):
+  t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
+
+  offsets: Dict[str, int] = {}
+  lens: Dict[str, int] = {}
+  def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
+    #print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
+    lens[storage[2]] = storage[4] * storage[1].itemsize
+    if storage[2] not in offsets: return None
+    byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
+    ret = t[byte_offset:byte_offset+prod(size)].cast(storage[1])
+
+    # 6 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
+    shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
+    permute_indexes = [len(shape_strides)-1-y for y in argsort([x[1] for x in shape_strides])]
+    if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
+      intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
+      assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
+      ret = ret.cpu().reshape(intermediate_shape).permute(permute_indexes)
+
+    return ret.reshape(size)
+
+  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2}
+  class TorchPickle(pickle.Unpickler):
+    def find_class(self, module, name): return intercept[name] if module.startswith("torch") else super().find_class(module, name)
+    def persistent_load(self, pid): return pid
+
+  if tuple(t[0:2].numpy()) == (0x50, 0x4b):
+    myzip = zipfile.ZipFile(fn, 'r')
+    base_name = myzip.namelist()[0].split('/', 1)[0]
+    for n in myzip.namelist():
+      if n.startswith(f'{base_name}/data/'):
+        with myzip.open(n) as myfile:
+          offsets[n.split("/")[-1]] = myfile._orig_compress_start # type: ignore
+    with myzip.open(f'{base_name}/data.pkl') as myfile:
+      return TorchPickle(myfile).load()
+  else:
+    with open(fn, "rb") as f:
+      pkl = TorchPickle(f)
+      _, _, _, rwd, _, ids, base_offset = pkl.load(), pkl.load(), pkl.load(), f.tell(), pkl.load(), pkl.load(), f.tell()
+      for i in ids:
+        offsets[i] = base_offset + 8
+        base_offset += 8 + lens[i]
+      f.seek(rwd)
+      return TorchPickle(f).load()
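The "6 lines to deal with permuted tensors" above recover a permutation purely from (size, stride) metadata: drop size-1 axes, argsort the strides to find storage order, and check the reordered strides are contiguous. A hedged numpy-only sketch of the same index arithmetic (numpy's argsort stands in for the tinygrad.helpers one):

import numpy as np

def recover_permute(size, stride):
  shape_strides = [(s, st) for s, st in zip(size, stride) if s != 1]
  order = np.argsort([x[1] for x in shape_strides])                 # ascending stride order
  permute_indexes = [int(len(shape_strides) - 1 - y) for y in order]
  intermediate_shape = tuple(shape_strides[x][0] for x in np.argsort(permute_indexes))
  return permute_indexes, intermediate_shape

# a tensor stored transposed: torch reports size (3, 4) with stride (1, 3)
print(recover_permute((3, 4), (1, 3)))   # -> ([1, 0], (4, 3)): read as 4x3, then permute back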
@@ -3,8 +3,9 @@ from __future__ import annotations
import math, functools, itertools, operator
import numpy as np
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
-from tinygrad.helpers import prod, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, LazyNumpyArray
+from tinygrad.helpers import prod, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
from tinygrad.lazy import Device, LazyBuffer
+from tinygrad.ops import LoadOps

# An instantiation of the Function is the Context
class Function:
@@ -33,24 +34,22 @@ class Tensor:
  no_grad: ClassVar[bool] = False
  default_type: ClassVar[DType] = dtypes.float32

-  def __init__(self, data:Union[int, float, list, LazyBuffer, LazyNumpyArray, np.ndarray], device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
+  def __init__(self, data:Union[int, float, list, LazyBuffer, np.ndarray], device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
    assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
    device = Device.canonicalize(device)
-    if isinstance(data, (int, float, list)):
+    if isinstance(data, list):
      data = np.array(data, dtype=(dtype if dtype is not None else Tensor.default_type).np)
    elif isinstance(data, LazyBuffer) and data.device != device:
      # TODO: this has to realize, it shouldn't have to
      data = data.realize().toCPU()

-    # all ndarrays are lazy now
-    if isinstance(data, np.ndarray): data = LazyNumpyArray(data, data.shape, data.dtype)
-
-    # by here, it's either LazyNumpyArray or LazyBuffer
-    # TODO: it should all be LazyBuffer I think
-    if isinstance(data, LazyNumpyArray):
-      lazydata = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None else data, device)
-    elif isinstance(data, LazyBuffer):
+    if isinstance(data, LazyBuffer):
      assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
      lazydata = data
+    elif isinstance(data, np.ndarray):
+      lazydata = LazyBuffer.loadop(LoadOps.FROMCPU, data.shape, dtypes.from_np(data.dtype), device, data)
+    elif isinstance(data, (int, float)):
+      lazydata = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype if dtype is not None else Tensor.default_type, device, data)
+    else:
+      raise RuntimeError(f"can't create Tensor from {data}")
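The constructor now has one explicit branch per input kind: a LazyBuffer is taken as-is, numpy arrays become a FROMCPU loadop carrying the array itself as the arg, and Python scalars become a zero-copy CONST. A small hedged sketch of the resulting behavior:

import numpy as np
from tinygrad.tensor import Tensor

t1 = Tensor(np.arange(6, dtype=np.float32).reshape(2, 3))  # ndarray -> LoadOps.FROMCPU
t2 = Tensor([1.0, 2.0, 3.0])                               # list -> np.array -> FROMCPU
t3 = Tensor(2.5)                                           # scalar -> LoadOps.CONST
print(t1.shape, t2.shape, t3.lazydata.op.op)               # e.g. (2, 3) (3,) LoadOps.CONST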
@@ -116,6 +115,24 @@ class Tensor:
    if self.grad: ret.grad = self.grad.to(device)
    return ret

+  # ***** creation llop entrypoint *****
+
+  @staticmethod
+  def _loadop(op, sz, device=Device.DEFAULT, dtype:Optional[DType]=None, arg=None, **kwargs):
+    return Tensor(LazyBuffer.loadop(op, [sz], Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs)
+
+  @staticmethod
+  def empty(*shape, **kwargs): return Tensor._loadop(LoadOps.EMPTY, prod(shape), **kwargs).reshape(shape)
+
+  _seed: int = 1337
+  @staticmethod
+  def manual_seed(seed=None): Tensor._seed = seed
+
+  @staticmethod
+  def rand(*shape, **kwargs):
+    Tensor._seed += 1
+    return Tensor._loadop(LoadOps.RAND, prod(shape), arg=Tensor._seed, **kwargs).reshape(shape)
+
  # ***** creation helper functions *****

  @staticmethod
@@ -127,9 +144,12 @@ class Tensor:
  @staticmethod
  def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1, **kwargs)

+  @staticmethod
+  def arange(stop, start=0, step=1, **kwargs): return Tensor.full(((stop-start)//step,), step).cumsum() + (start - step)
+
  @staticmethod
  def full_like(tensor, fill_value, dtype:Optional[DType]=None, **kwargs):
-    return Tensor.full(tensor.shape, fill_value, dtype=tensor.dtype if dtype is None else dtype, **kwargs)
+    return Tensor.full(tensor.shape, fill_value=fill_value, dtype=tensor.dtype if dtype is None else dtype, **kwargs)

  @staticmethod
  def zeros_like(tensor, **kwargs): return Tensor.full_like(tensor, 0, **kwargs)
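Why full(...).cumsum() + (start - step) is arange (this is the "llama uses arange kernels" item from the commit message): a cumulative sum over a constant step gives step, 2*step, 3*step, ..., and the (start - step) shift moves the first element to start. A quick numpy check of the identity, not the tinygrad code:

import numpy as np

def arange_like(stop, start=0, step=1):
  n = (stop - start) // step                          # number of elements
  return np.full((n,), step, dtype=np.float32).cumsum() + (start - step)

assert (arange_like(10) == np.arange(10)).all()
assert (arange_like(9, start=3, step=2) == np.arange(3, 9, 2)).all()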
@@ -137,39 +157,21 @@ class Tensor:
  @staticmethod
  def ones_like(tensor, **kwargs): return Tensor.full_like(tensor, 1, **kwargs)

-  @staticmethod
-  def empty(*shape, device=Device.DEFAULT, dtype:Optional[DType]=None, **kwargs):
-    # NOTE: we do the reshape to fix interpreted buffers
-    return Tensor(LazyBuffer.empty([prod(shape)], Tensor.default_type if dtype is None else dtype, Device.canonicalize(device)), dtype=dtype, device=device, **kwargs).reshape(*shape)
-
  @staticmethod
  def eye(dim, **kwargs): return Tensor([1], **kwargs).slice(((0,dim+1),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).slice(((0,dim*dim),)).reshape(dim, dim)

-  # TODO: below line, remove use of numpy here and make lazy
-  @staticmethod
-  def arange(stop, start=0, step=1, **kwargs): return Tensor(np.arange(start=start, stop=stop, step=step, dtype=np.float32), **kwargs)
-
+  # TODO: requires cumsum to remove numpy
  def where(self:Tensor, input_:Union[Tensor, float], other:Union[Tensor, float]):
    cond = (self != 0.0)
    return cond * input_ + (1.0 - cond) * other

-  # ***** (numpy) rng helper functions *****
-  # TODO: move randomness generation out of numpy
-
-  _rng: ClassVar[np.random.Generator] = np.random.default_rng()
-  @staticmethod
-  def manual_seed(seed=None): Tensor._rng = np.random.default_rng(seed)
-
-  @staticmethod
-  def rand(*shape, **kwargs) -> Tensor: return Tensor(LazyNumpyArray(lambda lna: Tensor._rng.random(size=lna.shape, dtype=lna.dtype), shape, np.float32), **kwargs)
-
-  # TODO: replace with a transformation from uniform -> gaussian
-  @staticmethod
-  def randn(*shape, **kwargs) -> Tensor: return Tensor(LazyNumpyArray(lambda lna: Tensor._rng.standard_normal(size=lna.shape, dtype=lna.dtype), shape, np.float32), **kwargs)
-
  # ***** rng hlops *****

  @staticmethod
+  def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
+    # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
+    src = Tensor.rand(2, *shape, **kwargs)
+    return src[0].mul(2*math.pi).cos().mul(src[1].log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
+
  @staticmethod
  def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor: return ((high-low) * Tensor.rand(*shape, **kwargs)) + low
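The new randn derives gaussians from the same uniform RAND source via the Box-Muller transform referenced above: with u0, u1 uniform in (0, 1), sqrt(-2 ln u1) * cos(2*pi*u0) is standard normal. A small numpy sanity check of the identity, not the tinygrad implementation:

import math
import numpy as np

rng = np.random.default_rng(1337)
u0, u1 = rng.random(100_000), rng.random(100_000)
z = np.sqrt(-2 * np.log(u1)) * np.cos(2 * math.pi * u0)
print(z.mean(), z.std())   # should be close to 0 and 1 respectively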
@@ -184,7 +186,7 @@ class Tensor:
  @staticmethod
  def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
    bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
-    return Tensor.uniform(*shape, low=-bound, high=bound)
+    return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)

  # ***** toposort and backward pass *****
  def deepwalk(self):
@@ -460,6 +462,9 @@ class Tensor:
    r = (x*w).sum(-1)
    return r.reshape((*r.shape[:-2], r.shape[-1])) if len(self.shape) == 1 else r

+  # TODO: make this work for n-dimensional inputs
+  def cumsum(self): return self.reshape(1, 1, 1, self.shape[0]).conv2d(Tensor.ones(1, 1, 1, self.shape[0]), padding=(self.shape[0] - 1, 0, 0, 0)).flatten()
+
  # ***** mlops (unary) *****

  def contiguous(self): return mlops.Contiguous.apply(self)
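cumsum above is expressed as a 1D convolution with an all-ones kernel and left padding of n-1, so output i sums inputs 0..i; this is what arange (and through it llama's position handling) now lowers to instead of np.arange. A hedged numpy illustration of the equivalence (plain correlation with a ones kernel, which is what the conv2d amounts to here):

import numpy as np

x = np.array([3., 1., 4., 1., 5.], dtype=np.float32)
n = len(x)
padded = np.concatenate([np.zeros(n - 1, dtype=np.float32), x])  # left-pad by n-1
ones_kernel = np.ones(n, dtype=np.float32)
out = np.array([np.dot(padded[i:i+n], ones_kernel) for i in range(n)])
assert np.allclose(out, np.cumsum(x))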