From ffa33d743a7119ab555a726cb0e421162b67e788 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Fri, 6 Oct 2023 13:33:24 -0700
Subject: [PATCH] good changes from openpilot_compile2 (#2000)

* good changes from openpilot_compile2

* float32 image type was wrong

* cleaner way to write that + a test
---
 extra/onnx.py                 | 13 +------------
 test/test_schedule.py         |  5 +++++
 test/unit/test_disk_tensor.py | 19 +++++++++++++++++++
 tinygrad/codegen/optimizer.py | 27 ++++++++++++++++-----------
 tinygrad/helpers.py           |  6 ++++++
 tinygrad/lazy.py              |  7 ++++++-
 tinygrad/nn/image.py          | 12 +++++-------
 tinygrad/nn/state.py          | 15 ++++++++-------
 tinygrad/realize.py           |  6 +++++-
 tinygrad/tensor.py            | 11 +++++++----
 10 files changed, 78 insertions(+), 43 deletions(-)

diff --git a/extra/onnx.py b/extra/onnx.py
index 3bef138180..b8d186eb56 100644
--- a/extra/onnx.py
+++ b/extra/onnx.py
@@ -19,8 +19,7 @@ def safe_numpy(t) -> np.ndarray:
   if not isinstance(t, Tensor): return t
   global numpy_cache
   if t not in numpy_cache:
-    if DEBUG >= 1:
-      print("numpy cache miss", t)
+    if DEBUG >= 3: print("numpy cache miss", t)
     tmp = t.numpy()
     numpy_cache[t] = tmp if len(tmp.shape) else tmp.reshape(1)
   assert len(numpy_cache[t].shape) > 0
@@ -95,9 +94,6 @@ def get_run_onnx(onnx_model: ModelProto):
       print(inp.name, inp.dims, inp.data_type, len(inp.raw_data))
       print(inp)
       raise Exception("no data")
-    if DEBUG >= 1:
-      print("realize", inp.name)
-    tensors[inp.name].realize()
 
   # preparse the attributes
   attribute_dict = {}
@@ -130,13 +126,6 @@ def get_run_onnx(onnx_model: ModelProto):
         if shape: # if only input_tensor is not variable type
           input_shape = input_tensors[inp.name].shape if isinstance(input_tensors[inp.name], Tensor) else (1, *[i.shape for i in input_tensors[inp.name]])
           assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}"
-        for _,v in input_tensors.items():
-          if isinstance(v, Tensor):
-            v.realize()
-          elif isinstance(v, list):
-            for v_ in v: v_.realize()
-          else:
-            raise Exception(f"unknown input type: {type(v)}")
       else:
         raise Exception(f"no data for {inp.name} with shape {shape}")
 
diff --git a/test/test_schedule.py b/test/test_schedule.py
index 2a90f23421..3e97f1bc07 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -322,5 +322,10 @@ class TestSchedule(unittest.TestCase):
     out = x.permute(0,2,3,1).contiguous()
     check_schedule(out, 2, filter_loadops=False)
 
+  def test_double_from(self):
+    x = Tensor([1,2,3,4])
+    out = x.to('cpu')
+    check_schedule(out, 0, filter_loadops=False)
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)
diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py
index 5e5ad88529..37ad99685d 100644
--- a/test/unit/test_disk_tensor.py
+++ b/test/unit/test_disk_tensor.py
@@ -86,6 +86,25 @@ class TestSafetensors(unittest.TestCase):
       for k in f.keys():
         np.testing.assert_array_equal(f.get_tensor(k).numpy(), state_dict[k].numpy())
 
+  def test_huggingface_enet_safetensors(self):
+    # test a real file
+    fn = fetch_as_file("https://huggingface.co/timm/mobilenetv3_small_075.lamb_in1k/resolve/main/model.safetensors")
+    state_dict = safe_load(fn)
+    assert len(state_dict.keys()) == 244
+    assert 'blocks.2.2.se.conv_reduce.weight' in state_dict
+    assert state_dict['blocks.0.0.bn1.num_batches_tracked'].numpy() == 276570
+    assert state_dict['blocks.2.0.bn2.num_batches_tracked'].numpy() == 276570
+
+  def test_metadata(self):
+    metadata = {"hello": "world"}
+    safe_save({}, temp('metadata.safetensors'), metadata)
+    import struct
+    with open(temp('metadata.safetensors'), 'rb') as f:
+      dat = f.read()
+    sz = struct.unpack(">Q", dat[0:8])[0]
+    import json
+    assert json.loads(dat[8:8+sz])['__metadata__']['hello'] == 'world'
+
 class TestDiskTensor(unittest.TestCase):
   def test_empty(self):
     pathlib.Path(temp("dt1")).unlink(missing_ok=True)
diff --git a/tinygrad/codegen/optimizer.py b/tinygrad/codegen/optimizer.py
index f71a743380..d4e5c325bb 100644
--- a/tinygrad/codegen/optimizer.py
+++ b/tinygrad/codegen/optimizer.py
@@ -3,8 +3,8 @@ import itertools, math, os
 from tinygrad.helpers import DEBUG, prod, getenv, ImageDType, dtypes
 from tinygrad.ops import ReduceOps, BinaryOps, UnaryOps, LazyOp, BufferOps
 from tinygrad.codegen.kernel import Kernel, LocalBuffer, LinearizerOptions
-from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.view import View
+from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
+from tinygrad.shape.view import View, strides_for_shape
 
 class OptimizedKernel(Kernel):
   def __init__(self, ast:LazyOp, opts:Optional[LinearizerOptions]=None, var_vals=None):
@@ -62,6 +62,19 @@ class OptimizedKernel(Kernel):
     if self.shape_len == 0: return
     shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]
 
+    # if it's an image, insert fake strides such that this fusion doesn't happen across image axes
+    if self.bufs[0].dtype.name.startswith('image'):
+      base_shape = self.bufs[0].dtype.shape
+      if shape_idx_groups := get_contraction(self.output_shape, base_shape):
+        special_strides: Tuple[int, ...] = tuple()
+        for i,g in enumerate(shape_idx_groups):
+          shape_piece = tuple(self.output_shape[x] for x in g)
+          assert prod(shape_piece) == base_shape[i], "get_contraction was wrong?"
+          special_strides += strides_for_shape(shape_piece)
+        # adding the fake image shape
+        shapes.append(self.output_shape)
+        strides.append(special_strides)
+
     # merge dimensions if we can, multi get_shape_strides
     # TODO: does this always preserve the reduce dimension, NO
     # TODO: move this into shapetracker, with tests!
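Note: below is a rough standalone sketch of the fake-stride trick in the hunk above, using simplified stand-ins for get_contraction and strides_for_shape (the helpers and the shapes are illustrative assumptions, not tinygrad's exact implementations). Each group of output dimensions that contracts to one image axis gets its own contiguous strides, so the dimension-merging pass cannot merge across an image-axis boundary.

# illustrative sketch: why per-axis strides block merging across image axes
from typing import List, Optional, Tuple

def prod(xs) -> int:
  out = 1
  for x in xs: out *= x
  return out

def strides_for_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
  # contiguous (row-major) strides, e.g. (8, 16) -> (16, 1)
  strides, acc = [], 1
  for s in reversed(shape):
    strides.append(acc)
    acc *= s
  return tuple(reversed(strides))

def get_contraction(big: Tuple[int, ...], small: Tuple[int, ...]) -> Optional[List[List[int]]]:
  # greedily group consecutive dims of `big` so each group multiplies to the matching dim of `small`
  groups, i = [], 0
  for tgt in small:
    g, cur = [], 1
    while cur < tgt:
      if i >= len(big): return None
      g.append(i); cur *= big[i]; i += 1
    if cur != tgt: return None
    groups.append(g)
  return groups if i == len(big) else None

output_shape, base_shape = (8, 16, 32, 4), (128, 32, 4)  # hypothetical kernel output shape and image shape
groups = get_contraction(output_shape, base_shape)       # [[0, 1], [2], [3]]
special_strides: Tuple[int, ...] = tuple()
for i, g in enumerate(groups):
  piece = tuple(output_shape[x] for x in g)
  assert prod(piece) == base_shape[i]
  special_strides += strides_for_shape(piece)
print(special_strides)  # (16, 1, 1, 1): the stride pattern restarts at every image axis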
@@ -78,7 +91,7 @@ class OptimizedKernel(Kernel):
       else: rets[j].append((shapes[j][i], strides[j][i]))
 
     # do the reshapes
-    for i,x in enumerate(rets): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
+    for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
 
   # ******************** GPU simplifiers ********************
   def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]:
@@ -354,14 +367,6 @@ class OptimizedKernel(Kernel):
     # simplify (sets first_reduce)
     self.simplify_ones()
 
-    # use more opencl indexing if the output buffer is an image and we have room
-    if self.bufs[0].dtype.name.startswith('image') and self.first_reduce+len(self.group_for_reduce) < 3:
-      base_shape = self.bufs[0].dtype.shape
-      if (base_shape[0]*base_shape[1]) % self.sts[0].shape[0] == 0 and self.sts[0].shape[0]//base_shape[0] != 0:
-        if DEBUG >= 4: print("split opencl", base_shape, self.sts[0].shape)
-        self.reshape_and_permute(lambda x: [base_shape[0], x[0]//base_shape[0]]+list(x[1:]), None)
-        self.simplify_ones()
-
     # no more opt if we are grouping
     if self.group_for_reduce: return
 
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 2200f9db32..7e82874167 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -129,6 +129,12 @@ class dtypes:
   _float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
   _arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)
 
+  # NOTE: these are image dtypes
+  @staticmethod
+  def imageh(shp): return ImageDType(100, 2, "imageh", np.float16, shp)
+  @staticmethod
+  def imagef(shp): return ImageDType(100, 4, "imagef", np.float32, shp)
+
 # HACK: staticmethods are not callable in 3.8 so we have to compare the class
 DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod}
 
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index b35184a192..85757dfe56 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -191,8 +191,13 @@ class LazyBuffer:
     # NOTE: dtypes.from_np(self.dtype.np) to deal with image types
     return self.loadop(LoadOps.CONST, tuple(), dtypes.from_np(self.dtype.np), self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape)
 
+  def copy_to_device(self, device:str) -> LazyBuffer:
+    # back off a FROM if it's a double FROM
+    if not self.realized and self.op.op == LoadOps.FROM and cast(LazyBuffer, self.op.src[0]).device == device: return cast(LazyBuffer, self.op.src[0])
+    return LazyBuffer.loadop(LoadOps.FROM, self.shape, self.dtype, device, src=self.contiguous())
+
   def contiguous(self:LazyBuffer) -> LazyBuffer:
-    if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one
+    if not self.realized and self.op.op in LoadOps and self.op.op != LoadOps.CONST: return self # all LoadOps are already contiguous (except CONST)
     if self.st.contiguous and self.st.size() == self.base.st.size() and not self.is_unrealized_const():
       # this will turn into nothing, it's based and a copy
       # TODO: based lazybuffers shouldn't take dtype or var_vals, same issue in movementops
diff --git a/tinygrad/nn/image.py b/tinygrad/nn/image.py
index 3939ac11f9..72bf4ecbd7 100644
--- a/tinygrad/nn/image.py
+++ b/tinygrad/nn/image.py
@@ -1,10 +1,6 @@
-import numpy as np
-from tinygrad.helpers import prod, IMAGE, ImageDType, getenv, dtypes
+from tinygrad.helpers import prod, IMAGE, getenv, dtypes
 from tinygrad.lazy import get_single_root
 
-FLOAT16 = getenv("FLOAT16", 0)
-base_image_type = (100, 2, "imageh", np.float16) if FLOAT16 else (100, 4, "imagef", np.float32)
-
 def image_dot(self, w):
   # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
   n1, n2 = len(self.shape), len(w.shape)
@@ -27,6 +23,8 @@ def image_dot(self, w):
   return cx.conv2d(cw, groups=groups).reshape(shape=out_shape_t).permute(order=order)
 
 def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, padding=0):
+  base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
+
   (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
   rcout = cout//groups
   x, w = self, weight.reshape(groups, rcout, cin, H, W)
@@ -56,7 +54,7 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, paddin
   else:
     w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1).reshape(cout//4, H*cin//4*W*4, 4)
 
   # contiguous creates the image, and early realize static weights (TODO: test for the static weight)
-  if IMAGE >= 2: x,w = x.cast(ImageDType(*base_image_type, shape=x.shape)), w.cast(ImageDType(*base_image_type, shape=w.shape))
+  if IMAGE >= 2: x,w = x.cast(base_image_type(x.shape)), w.cast(base_image_type(w.shape))
   x, w = x.contiguous(), w.contiguous()
   if get_single_root(w.lazydata).realized: w.realize()
@@ -86,7 +84,7 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, paddin
 
   # reshape to image and cast back to image
   ret = ret.reshape(bs*oy, ox*cout//4, 4)
-  if IMAGE >= 2: ret = ret.cast(ImageDType(*base_image_type, shape=ret.shape))
+  if IMAGE >= 2: ret = ret.cast(base_image_type(ret.shape))
   if IMAGE >= 3: ret = ret.contiguous()
 
   # undo hack for non multiples of 4 on C.rcout
diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py
index 67c20b3341..fb427f938a 100644
--- a/tinygrad/nn/state.py
+++ b/tinygrad/nn/state.py
@@ -1,6 +1,6 @@
 import os, json, pathlib, zipfile, pickle
 from tqdm import tqdm
-from typing import Dict, Union, List
+from typing import Dict, Union, List, Optional, Any
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters, CI
 from tinygrad.shape.view import strides_for_shape
@@ -12,15 +12,16 @@ inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
 def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
   t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
   json_len = t[0:1].cast(dtypes.int64).numpy()[0]
-  metadata = json.loads(t[8:8+json_len].numpy().tobytes())
-  return {k:t[8+json_len+v['data_offsets'][0]:].cast(safe_dtypes[v['dtype']])[:prod(v['shape'])].reshape(v['shape']) for k,v in metadata.items() if k != "__metadata__"}
+  headers = json.loads(t[8:8+json_len].numpy().tobytes())
+  return {k:t[8+json_len+v['data_offsets'][0]:].cast(safe_dtypes[v['dtype']])[:prod(v['shape'])].reshape(v['shape']) for k,v in headers.items() if k != "__metadata__"}
 
-def safe_save(tensors:Dict[str, Tensor], fn:str):
-  metadata, offset = {}, 0
+def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
+  headers, offset = {}, 0
+  if metadata: headers['__metadata__'] = metadata
   for k,v in tensors.items():
-    metadata[k] = {'dtype': inverse_safe_dtypes[v.dtype], 'shape': list(v.shape), 'data_offsets':[offset, offset+v.nbytes()]}
+    headers[k] = {'dtype': inverse_safe_dtypes[v.dtype], 'shape': list(v.shape), 'data_offsets':[offset, offset+v.nbytes()]}
     offset += v.nbytes()
-  j = json.dumps(metadata, separators=(',', ':'))
+  j = json.dumps(headers, separators=(',', ':'))
   j += "\x20"*((8-len(j)%8)%8)
   pathlib.Path(fn).unlink(missing_ok=True)
   t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
diff --git a/tinygrad/realize.py b/tinygrad/realize.py
index 9edba50161..427cc1ed1a 100644
--- a/tinygrad/realize.py
+++ b/tinygrad/realize.py
@@ -50,6 +50,7 @@ def run_schedule(schedule:List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]
   while len(schedule):
     op,out,buffers = schedule.pop(0)
     log_schedule_item(op, out, buffers)
+    assert all(x.realized for x in buffers), "can't run schedule, some buffers aren't realized"
     if DEBUG >= 3:
       from extra.utils import print_tree # type: ignore
       print_tree(op)
@@ -68,10 +69,12 @@ def run_schedule(schedule:List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]
 
 def _realize_empty(buffer: LazyBuffer) -> None:
   assert all_int(buffer.shape), "does not support symbolic shape"
+  if DEBUG >= 2: print(f"*** empty {buffer.device} shape {str(buffer.shape):23s} dtype {buffer.dtype}")
   buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())
 
 def _realize_rand(buffer: LazyBuffer) -> None:
   assert all_int(buffer.shape), "does not support symbolic shape"
+  if DEBUG >= 2: print(f"*** rand {buffer.device} shape {str(buffer.shape):23s} dtype {buffer.dtype}")
   rng = np.random.default_rng(buffer.op.arg)
   buffer.realized = Device[buffer.device].buffer.fromCPU(rng.random(size=prod(buffer.shape), dtype=np.float32).astype(dtype=buffer.dtype.np, copy=False), **buffer._device_extra_args())
 
@@ -80,7 +83,7 @@ def _realize_from(buffer: LazyBuffer, src: LazyBuffer) -> None:
   assert src.realized.size == buffer.st.size(), f"size mismatch on FROM {src.realized.size} != {buffer.st.size()}"
   assert src.st.contiguous and buffer.st.contiguous, "all must be contiguous for from"
-  if DEBUG >= 3: print(f"*** copy {buffer.device} <- {src.device} size {src.realized.size} dtype {src.realized.dtype}")
+  if DEBUG >= 2: print(f"*** copy {buffer.device} <- {src.device} size {src.realized.size:16d} shape {str(buffer.shape):23s} dtype {src.realized.dtype}")
   # TODO: make this generic
   if isinstance(src.realized, RawDiskBuffer) and issubclass(Device[buffer.device].buffer, RawBufferMapped):
     assert all_int(buffer.shape), "does not support symbolic shape"
@@ -95,6 +98,7 @@ def _realize_from(buffer: LazyBuffer, src: LazyBuffer) -> None:
 # *** n op LoadOps ***
 
 def _realize_custom(buffer: LazyBuffer, *inputs: LazyBuffer) -> None:
+  if DEBUG >= 2: print(f"*** custom {buffer.device} shape {str(buffer.shape):23s} dtype {buffer.dtype}")
   buffer.realized = buffer.op.arg(buffer, *inputs)
 
 LOAD_OPS_DISPATCHER: Dict[LoadOps, Callable] = {
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 73571e2561..6b4fd11779 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -61,17 +61,20 @@ class Tensor:
     self._ctx: Optional[Function] = None
     if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
     elif isinstance(data, (int, float)):
-      self.lazydata = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or Tensor.default_type, device, data)
-      return
+      data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or Tensor.default_type, device, data)
     elif data.__class__ is list:
       assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
       data = LazyBuffer.fromCPU(np.array(data, dtype=(dtype or Tensor.default_type).np))
     elif isinstance(data, np.ndarray):
       assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
-      data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
+      if data.shape == ():
+        data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
+      else:
+        data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
     else: raise RuntimeError(f"can't create Tensor from {data}")
 
-    self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data.contiguous())
+    # data is a LazyBuffer, but it might be on the wrong device
+    self.lazydata = data if data.device == device else data.copy_to_device(device)
 
   def __repr__(self):
     return f""
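For reference, a usage sketch of the optional metadata argument this patch adds to safe_save (the file path and tensor name below are made up). The on-disk layout is the standard safetensors one: an 8-byte little-endian header length, the JSON header (the extra dict lands under "__metadata__"), then the raw tensor bytes. Since safe_load skips "__metadata__", the header is parsed by hand to read it back.

# usage sketch for safe_save(..., metadata=...); path and tensor name are hypothetical
import json
from tinygrad.tensor import Tensor
from tinygrad.nn.state import safe_save, safe_load

safe_save({"weight": Tensor([1.0, 2.0, 3.0, 4.0])}, "/tmp/example.safetensors", metadata={"framework": "tinygrad"})

with open("/tmp/example.safetensors", "rb") as f:
  dat = f.read()
hdr_len = int.from_bytes(dat[0:8], "little")  # 8-byte little-endian header length
print(json.loads(dat[8:8+hdr_len])["__metadata__"])             # {'framework': 'tinygrad'}
print(safe_load("/tmp/example.safetensors")["weight"].numpy())  # [1. 2. 3. 4.]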