From ed1963b8991a41335ea19cd94a6eddb10870f708 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Sat, 3 Jun 2023 12:25:41 -0700
Subject: [PATCH] Fast DiskTensor to other Tensor (#916)

* make disktensors fast

* loading

* loader for sd and llama
---
 examples/llama.py            | 13 ++++------
 examples/stable_diffusion.py | 26 ++++----------------
 extra/disk/test.cc           | 25 ++++++++++++++++----
 extra/helpers.py             |  9 +------
 tinygrad/helpers.py          |  9 ++++++-
 tinygrad/lazy.py             | 15 +++++++++---
 tinygrad/nn/optim.py         | 17 ++++---------
 tinygrad/ops.py              |  2 +-
 tinygrad/runtime/ops_disk.py | 16 +++++++++----
 tinygrad/state.py            | 46 ++++++++++++++++++++++++++++++------
 tinygrad/tensor.py           |  7 +++---
 11 files changed, 109 insertions(+), 76 deletions(-)

diff --git a/examples/llama.py b/examples/llama.py
index 9c07d20490..8d5c5edadc 100755
--- a/examples/llama.py
+++ b/examples/llama.py
@@ -207,6 +207,7 @@ if __name__ == "__main__":
   args = parser.parse_args()
   chatbot = args.prompt == None

+  """
   # load model (you have to find the weights yourself)
   from extra.utils import fake_torch_load_zipped, get_child

@@ -262,18 +263,12 @@ if __name__ == "__main__":
       get_child(model, k).assign(v).realize()
     del weights
+  """

   # disktensor loader isn't fast yet
-  """
-  from tinygrad.state import torch_load, get_state_dict
-  state_dict = torch_load(WEIGHTS_7B_FILENAME)
   model = Transformer(**args_7B)
-  with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"):
-    for k,v in (t := tqdm(get_state_dict(model).items())):
-      t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, loading {k}")
-      if k not in state_dict: continue
-      v.assign(state_dict[k].to(v.device)).realize()
-  """
+  from tinygrad.state import torch_load, load_state_dict
+  load_state_dict(model, torch_load(WEIGHTS_7B_FILENAME), strict=False)

   # *** prompt engineers work here ****
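The LLaMA example now goes through the generic loader instead of a hand-rolled copy loop. A sketch of the new flow in isolation, assuming Transformer, args_7B, and WEIGHTS_7B_FILENAME from examples/llama.py are in scope:

  from tinygrad.state import torch_load, load_state_dict

  model = Transformer(**args_7B)                    # weights start as unrealized lazy tensors
  state_dict = torch_load(WEIGHTS_7B_FILENAME)      # returns disk-backed tensors, reads no data yet
  load_state_dict(model, state_dict, strict=False)  # strict=False skips model tensors absent from the checkpoint

Each assign(...).realize() inside load_state_dict is what actually pulls bytes off disk, taking the readinto fast path added to lazy.py below.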
diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py
index 40cfab57cd..7a04cf8115 100644
--- a/examples/stable_diffusion.py
+++ b/examples/stable_diffusion.py
@@ -2,10 +2,7 @@
 # https://github.com/ekagra-ranjan/huggingface-blog/blob/main/stable_diffusion.md

 from pathlib import Path
-import gzip
-import argparse
-import math
-import re
+import gzip, argparse, math, re
 from functools import lru_cache
 from collections import namedtuple

@@ -14,7 +11,8 @@ from tqdm import tqdm
 from tinygrad.tensor import Tensor
 from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm
-from extra.utils import fake_torch_load_zipped, get_child, download_file
+from extra.utils import download_file
+from tinygrad.state import torch_load, load_state_dict

 # TODO: refactor AttnBlock, CrossAttention, CLIPAttention to share code

@@ -613,24 +611,10 @@ if __name__ == "__main__":
   model = StableDiffusion()

   # load in weights
-  download_file(
-    'https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt',
-    FILENAME
-  )
-  dat = fake_torch_load_zipped(open(FILENAME, "rb"))
-  for k,v in dat['state_dict'].items():
-    try:
-      w = get_child(model, k)
-    except (AttributeError, KeyError, IndexError):
-      #traceback.print_exc()
-      w = None
-    #print(f"{str(v.shape):30s}" if v is not None else v, w.shape if w is not None else w, k)
-    if w is not None:
-      assert w.shape == v.shape and w.dtype == v.dtype, f"shape or dtype mismatch. {w.shape} != {v.shape} or {w.dtype} != {v.dtype}"
-      w.assign(v)
+  download_file('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', FILENAME)
+  load_state_dict(model, torch_load(FILENAME)['state_dict'], strict=False)

   # run through CLIP to get context
-  tokenizer = ClipTokenizer()
   prompt = tokenizer.encode(args.prompt)
   context = model.cond_stage_model.transformer.text_model(prompt).realize()

diff --git a/extra/disk/test.cc b/extra/disk/test.cc
index 497b5e1f50..eb1d5da815 100644
--- a/extra/disk/test.cc
+++ b/extra/disk/test.cc
@@ -10,11 +10,19 @@
 #include
 #include

+//#define FN "/dev/nvme0n1"
+#define FN "../../weights/LLaMA/7B/consolidated.00.pth"
+
 #define SZ (unsigned long long)(512*1024*1024)
 #define CNT 10LL

 void test_read() {
-  int f = open("/dev/nvme0n1", O_RDONLY|O_DIRECT);
+#ifdef O_DIRECT
+  int f = open(FN, O_RDONLY|O_DIRECT);
+#else
+  int f = open(FN, O_RDONLY);
+  //fcntl(f, F_NOCACHE, 1);
+#endif
   printf("open %d\n", f);

   /*void *buf = malloc(CNT*SZ);
@@ -42,7 +50,11 @@ void test_read() {
 }

 void test_mmap() {
-  int f = open("/dev/nvme0n1", O_RDONLY|O_DIRECT);
+#ifdef O_DIRECT
+  int f = open(FN, O_RDONLY|O_DIRECT);
+#else
+  int f = open(FN, O_RDONLY);
+#endif
   printf("open %d\n", f);

   void *dat = mmap(NULL, SZ*CNT, PROT_READ, MAP_PRIVATE, f, 0);
@@ -62,10 +74,13 @@ void test_mmap() {
 }

 int main() {
-  system("sync; echo 1 > /proc/sys/vm/drop_caches");
-  test_mmap();
+  //system("sync; echo 1 > /proc/sys/vm/drop_caches");
+  //system("sudo purge");
+  //test_mmap();

   //system("sync; echo 1 > /proc/sys/vm/drop_caches");
-  //test_read();
+  system("sudo purge");
+  test_read();
+  test_read();
 }
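The C benchmark probes whether mmap page-faulting can keep up with one big sequential read(); the drop_caches / sudo purge calls are there because a warm page cache makes both paths look infinitely fast. For a quick comparison without compiling anything, here is a rough Python analogue, a sketch only (FN is whatever large local file you have; the division yields GB/s because it is bytes over nanoseconds, and the read variant allocates a file-sized buffer):

  import mmap, os, time

  FN = "../../weights/LLaMA/7B/consolidated.00.pth"  # any large local file works

  def bench_read(fn):
    sz = os.path.getsize(fn)
    buf = bytearray(sz)
    st = time.perf_counter_ns()
    with open(fn, "rb") as f:
      f.readinto(buf)                                # one big sequential read
    print(f"read: {sz/(time.perf_counter_ns()-st):.2f} GB/s")

  def bench_mmap(fn):
    sz = os.path.getsize(fn)
    with open(fn, "rb") as f, mmap.mmap(f.fileno(), sz, access=mmap.ACCESS_READ) as m:
      st = time.perf_counter_ns()
      total = sum(m[i] for i in range(0, sz, 4096))  # fault in every page
      print(f"mmap: {sz/(time.perf_counter_ns()-st):.2f} GB/s")

  bench_read(FN)
  bench_mmap(FN)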
diff --git a/extra/helpers.py b/extra/helpers.py
index 0084963ef9..7a1daae15e 100644
--- a/extra/helpers.py
+++ b/extra/helpers.py
@@ -1,11 +1,4 @@
-import time
-
-class Timing(object):
-  def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
-  def __enter__(self): self.st = time.perf_counter_ns()
-  def __exit__(self, exc_type, exc_val, exc_tb):
-    self.et = time.perf_counter_ns() - self.st
-    if self.enabled: print(f"{self.prefix}{self.et*1e-6:.2f} ms"+(self.on_exit(self.et) if self.on_exit else ""))
+from tinygrad.helpers import Timing

 def enable_early_exec():
   import subprocess, multiprocessing

diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index a8087f29c9..450282b3ab 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 from dataclasses import dataclass, asdict
-import os, math, functools
+import os, math, functools, time
 import numpy as np
 from typing import Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any
 ShapeType = Tuple[int, ...]
@@ -39,6 +39,13 @@ class ContextVar:

 DEBUG, IMAGE = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0)

+class Timing(object):
+  def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
+  def __enter__(self): self.st = time.perf_counter_ns()
+  def __exit__(self, exc_type, exc_val, exc_tb):
+    self.et = time.perf_counter_ns() - self.st
+    if self.enabled: print(f"{self.prefix}{self.et*1e-6:.2f} ms"+(self.on_exit(self.et) if self.on_exit else ""))
+
 # **** tinygrad now supports dtypes! *****

 class DType(NamedTuple):
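Timing moves from extra.helpers into tinygrad.helpers so that tinygrad.state can report load throughput without depending on extra. A minimal usage sketch (the 100 MiB figure is an arbitrary example): the block prints the elapsed milliseconds on exit, and on_exit, if given, receives the elapsed nanoseconds and returns extra text to append.

  import time
  from tinygrad.helpers import Timing

  with Timing("slept for "):            # prints e.g. "slept for 10.41 ms"
    time.sleep(0.01)

  nbytes = 100 * 1024 * 1024            # pretend we moved 100 MiB
  with Timing("copied in ", on_exit=lambda ns: f", {nbytes/ns:.2f} GB/s"):
    dst = bytearray(nbytes)             # stand-in for a real copy

This is exactly the shape of the output load_state_dict produces: a prefix, the wall time, and a throughput suffix computed from the nanosecond count.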
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 827999abe7..74bbc2bade 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -6,7 +6,8 @@ from weakref import WeakValueDictionary
 from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, ImageDType, DEBUG
 from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
 from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, get_lazyops, get_buffers, map_buffers
-from tinygrad.runtime.lib import RawConst, RawBuffer
+from tinygrad.runtime.lib import RawConst, RawBuffer, RawBufferMapped
+from tinygrad.runtime.ops_disk import RawDiskBuffer

 # lazy can recurse a lot
 sys.setrecursionlimit(10000)
@@ -122,6 +123,14 @@ class LazyBuffer:
     elif self.op.op == LoadOps.CUSTOM:
       # this needs to immediately realize
       self.realized = self.op.arg(self, *[x.realize() for x in self.op.src])
+    elif self.op.op == LoadOps.FROM:
+      rawbuf = self.op.src[0].realize()
+      # TODO: make this generic
+      if isinstance(rawbuf.realized, RawDiskBuffer) and issubclass(Device[self.device].buffer, RawBufferMapped):
+        self.realized = Device[self.device].buffer(prod(self.shape), self.dtype, **self._device_extra_args())
+        rawbuf.realized.readinto(cast(RawBufferMapped, self.realized)._buffer())
+      else:
+        self.realized = Device[self.device].buffer.fromCPU(rawbuf.toCPU(), **self._device_extra_args())
     elif self.optype == LoadOps:
       if DEBUG >= 4: print(f"{self.op.op} {self.shape} {self.dtype} {self.op.arg}")
       if self.op.op == LoadOps.EMPTY:
@@ -167,8 +176,8 @@ class LazyBuffer:
     return self

   @staticmethod
-  def loadop(op, shape, dtype, device, arg=None) -> LazyBuffer:
-    return create_lazybuffer(device, shape, LoadOps, LazyOp(op, tuple(), arg), dtype)
+  def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
+    return create_lazybuffer(device, shape, LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype)

   # create a constant with the shape and dtype of self
   def const_like(self, val) -> LazyBuffer:

diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py
index b4a5d19c5a..346e9c1663 100644
--- a/tinygrad/nn/optim.py
+++ b/tinygrad/nn/optim.py
@@ -1,5 +1,5 @@
 # sorted in order of increasing complexity
-from typing import List, Dict
+from typing import List
 from tinygrad.tensor import Tensor

 class Optimizer:
@@ -67,15 +67,6 @@ class LAMB(Optimizer):
         t.assign(t.detach() - self.lr * r * up)
     self.realize([self.t] + self.m + self.v)

-from collections import OrderedDict
-def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
-  if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
-  if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
-  if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix, tensor_type)
-  state_dict = {}
-  if isinstance(obj, (list, tuple)):
-    for i,x in enumerate(obj): state_dict.update(get_state_dict(x, f"{prefix}{str(i)}.", tensor_type))
-  elif isinstance(obj, dict):
-    for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
-  return state_dict
-def get_parameters(obj) -> List[Tensor]: return list(get_state_dict(obj).values())
+# TODO: remove this
+from tinygrad.state import get_state_dict, get_parameters # pylint: disable=unused-import # noqa: F401
+

diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 6635b8ca42..f2cd5e0a49 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -12,7 +12,7 @@ class UnaryOps(Enum): NOOP = auto(); EXP = auto(); LOG = auto(); CAST = auto();
 class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto() # noqa: E702
 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class FusedOps(Enum): MULACC = auto() # noqa: E702
-class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROMCPU = auto(); CONTIGUOUS = auto(); TOCPU = auto(); CUSTOM = auto() # noqa: E702
+class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); FROMCPU = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702

 Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, FusedOps]
 OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[FusedOps]]
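The new LoadOps.FROM realize path is where the speedup lives: if the source realizes to a RawDiskBuffer and the destination device buffer is memory-mapped, the file is read straight into the destination's mapping; otherwise it falls back to the generic fromCPU(toCPU()) numpy round trip. An end-to-end sketch, assuming the CLANG backend (whose RawMallocBuffer is a RawBufferMapped), a throwaway file path, and that Tensor.empty accepts a disk: device string as the disk tensor tests of this era do:

  import numpy as np
  from tinygrad.tensor import Tensor
  from tinygrad.helpers import dtypes

  np.arange(16, dtype=np.float32).tofile("/tmp/w.bin")                  # example file, any path works
  t = Tensor.empty(16, dtype=dtypes.float32, device="disk:/tmp/w.bin")  # views the file, reads nothing yet
  c = t.to("CLANG")   # builds a LoadOps.FROM node; still lazy
  print(c.numpy())    # realize: readinto() copies file -> mapped buffer, no numpy staging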
diff --git a/tinygrad/runtime/ops_disk.py b/tinygrad/runtime/ops_disk.py
index 6e07c937ca..8ef16060d3 100644
--- a/tinygrad/runtime/ops_disk.py
+++ b/tinygrad/runtime/ops_disk.py
@@ -11,11 +11,16 @@ class RawDiskBuffer(RawBufferMapped):
     self.offset = offset # this is an offset in bytes
     assert device is not None or buf is not None, "disk tensor needs a path or a buf"
     if device is not None:
-      with open(device, "a+b") as f:
-        if os.path.getsize(device) < size * dtype.itemsize: os.ftruncate(f.fileno(), size * dtype.itemsize)
-        buf = mmap.mmap(f.fileno(), size * dtype.itemsize)
+      f = open(device, "a+b")
+      if os.path.getsize(device) < size * dtype.itemsize: os.ftruncate(f.fileno(), size * dtype.itemsize)
+      buf = [f, mmap.mmap(f.fileno(), size * dtype.itemsize), 1]
+    else:
+      buf[2] += 1
     # NOTE: we don't call super since disk tensors don't use RAM
     self.size, self.dtype, self._buf = size, dtype, buf
+  def __del__(self):
+    self._buf[2] -= 1
+    if self._buf[2] == 0: self._buf[0].close()
   def cast(self, new_dtype:DType): return RawDiskBuffer(self.size, new_dtype, buf=self._buf, shape=self.shape, offset=self.offset)
   def reshape(self, arg): return RawDiskBuffer(self.size, self.dtype, buf=self._buf, shape=arg, offset=self.offset)
   def shrink(self, arg):
@@ -23,7 +28,10 @@ class RawDiskBuffer(RawBufferMapped):
     offset = arg[0][0]*prod(self.shape[1:])*self.dtype.itemsize
     size = (arg[0][1]-arg[0][0]) * prod(self.shape[1:])
     return RawDiskBuffer(size, self.dtype, buf=self._buf, offset=self.offset+offset, shape=(arg[0][1]-arg[0][0],)+self.shape[1:])
-  def _buffer(self): return memoryview(self._buf)[self.offset:self.offset+self.size*self.dtype.itemsize]
+  def _buffer(self): return memoryview(self._buf[1])[self.offset:self.offset+self.size*self.dtype.itemsize]
+  def readinto(self, buf):
+    self._buf[0].seek(self.offset)
+    self._buf[0].readinto(buf)

 disk_fxn_for_op: Dict[Op, Callable] = { UnaryOps.NOOP: lambda x: x, UnaryOps.CAST: RawDiskBuffer.cast, MovementOps.RESHAPE: RawDiskBuffer.reshape, MovementOps.SHRINK: RawDiskBuffer.shrink }

diff --git a/tinygrad/state.py b/tinygrad/state.py
index 658e71ac65..567c058b92 100644
--- a/tinygrad/state.py
+++ b/tinygrad/state.py
@@ -1,7 +1,8 @@
 import os, json, pathlib, zipfile, pickle
-from typing import Dict, Union
+from tqdm import tqdm
+from typing import Dict, Union, List
 from tinygrad.tensor import Tensor
-from tinygrad.helpers import dtypes, prod, argsort
+from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters
 from tinygrad.shape.shapetracker import strides_for_shape

 safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64}
@@ -26,8 +27,30 @@ def safe_save(tensors:Dict[str, Tensor], fn:str):
     t[8:8+len(j)].assign(Tensor(list(j.encode('utf-8')), dtype=dtypes.uint8))
   for k,v in safe_load(t).items(): v.assign(tensors[k])

-# TODO: move get_state_dict and get_parameters here
-from tinygrad.nn.optim import get_state_dict, get_parameters # pylint: disable=unused-import # noqa: F401
+# state dict
+
+from collections import OrderedDict
+def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
+  if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
+  if hasattr(obj, '_asdict'): return get_state_dict(obj._asdict(), prefix, tensor_type) # namedtuple
+  if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
+  if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix, tensor_type)
+  state_dict = {}
+  if isinstance(obj, (list, tuple)):
+    for i,x in enumerate(obj): state_dict.update(get_state_dict(x, f"{prefix}{str(i)}.", tensor_type))
+  elif isinstance(obj, dict):
+    for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
+  return state_dict
+def get_parameters(obj) -> List[Tensor]: return list(get_state_dict(obj).values())
+
+def load_state_dict(model, state_dict, strict=True):
+  with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"):
+    for k,v in (t := tqdm(get_state_dict(model).items())):
+      t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}")
+      if k not in state_dict and not strict:
+        if DEBUG >= 2: print(f"WARNING: not loading {k}")
+        continue
+      v.assign(state_dict[k].to(v.device)).realize()

 # torch support!

@@ -43,19 +66,28 @@ def torch_load(fn:str):
       byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
       ret = t[byte_offset:byte_offset+prod(size)].cast(storage[1])

-      # 6 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
+      # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
       shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
       permute_indexes = [len(shape_strides)-1-y for y in argsort([x[1] for x in shape_strides])]
       if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
         intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
         assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
+        if DEBUG >= 2: print(f"WARNING: this torch load is slow. it has to convert to CPU to permute {permute_indexes}")
+        # TODO: find a nice way to support all shapetracker on disktensors
         ret = ret.cpu().reshape(intermediate_shape).permute(permute_indexes)

       return ret.reshape(size)

-  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2}
+  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "IntStorage": dtypes.int32, "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2}
+  whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed
+  class Dummy: pass
   class TorchPickle(pickle.Unpickler):
-    def find_class(self, module, name): return intercept[name] if module.startswith("torch") else super().find_class(module, name)
+    def find_class(self, module, name):
+      module_root = module.split(".")[0]
+      if module_root not in whitelist:
+        if DEBUG >= 2: print(f"WARNING: returning Dummy for {module} {name}")
+        return Dummy
+      return intercept[name] if module_root == "torch" else super().find_class(module, name)
     def persistent_load(self, pid): return pid

   if tuple(t[0:2].numpy()) == (0x50, 0x4b):
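get_state_dict and get_parameters now live in tinygrad.state, with optim re-exporting them for compatibility. A small illustration of how the recursive walk names tensors (the Net and Linear classes here are hypothetical stand-ins, not the real tinygrad.nn modules): attributes, dict keys, and list indices become dot-separated key segments.

  from tinygrad.tensor import Tensor
  from tinygrad.state import get_state_dict

  class Linear:
    def __init__(self): self.weight, self.bias = Tensor.ones(4,4), Tensor.zeros(4)
  class Net:
    def __init__(self): self.layers = [Linear(), Linear()]

  print(list(get_state_dict(Net()).keys()))
  # ['layers.0.weight', 'layers.0.bias', 'layers.1.weight', 'layers.1.bias']

load_state_dict matches these keys against the checkpoint; with strict=False, model tensors missing from the checkpoint are skipped, with a warning printed under DEBUG=2.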
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 15ea17439f..24b8b8f3c7 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -39,14 +39,12 @@ class Tensor:
     device = Device.canonicalize(device)
     if isinstance(data, list):
       data = np.array(data, dtype=(dtype if dtype is not None else Tensor.default_type).np)
-    elif isinstance(data, LazyBuffer) and data.device != device:
-      # TODO: this has to realize, it shouldn't have to
-      data = data.realize().toCPU()

     if isinstance(data, LazyBuffer):
       assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
-      lazydata = data
+      lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
     elif isinstance(data, np.ndarray):
+      # TODO: create CPUBuffer directly
       lazydata = LazyBuffer.loadop(LoadOps.FROMCPU, data.shape, dtypes.from_np(data.dtype), device, data)
     elif isinstance(data, (int, float)):
       lazydata = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype if dtype is not None else Tensor.default_type, device, data)
@@ -94,6 +92,7 @@ class Tensor:
       self.lazydata.realize().realized._copyin(x.numpy()) # type: ignore
       return self
     if not isinstance(x, Tensor): x = Tensor(x, device=self.device, dtype=self.dtype)
+    # NOTE: we are currently allowing assignments from different dtypes
     assert self.shape == x.shape and self.device == x.device, f"assign shape mismatch {self.shape} != {x.shape} or device mismatch {self.device} != {x.device}"
     assert not x.requires_grad # self requires_grad is okay?
     if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
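With this tensor.py change, constructing a Tensor from a LazyBuffer that lives on another device no longer forces an eager realize-to-numpy round trip; it schedules a LoadOps.FROM node instead. A quick sketch of that entry point (device names depend on which backends your machine has; "CPU" and "CLANG" are assumed here):

  from tinygrad.tensor import Tensor

  a = Tensor([1.0, 2.0, 3.0], device="CPU")
  b = Tensor(a.lazydata, device="CLANG")  # cross-device: schedules LoadOps.FROM, nothing copied yet
  print(b.numpy())                        # realizing b performs the transfer

This is the same mechanism Tensor.to and load_state_dict ride on: when the source is a disk tensor it hits the readinto fast path, and for other sources it falls back to the generic fromCPU(toCPU()) copy.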