diff --git a/examples/mlperf/dataloader.py b/examples/mlperf/dataloader.py new file mode 100644 index 0000000000..f4869de422 --- /dev/null +++ b/examples/mlperf/dataloader.py @@ -0,0 +1,123 @@ +import random, time, ctypes, struct +import numpy as np +from PIL import Image +from tqdm import tqdm +import pickle +from tinygrad import dtypes, Tensor +from tinygrad.helpers import getenv, prod, Timing +from multiprocessing import Queue, Process, shared_memory, connection, Lock + +class MyQueue: + def __init__(self, multiple_readers=True, multiple_writers=True): + self._reader, self._writer = connection.Pipe(duplex=False) + self._rlock = Lock() if multiple_readers else None + self._wlock = Lock() if multiple_writers else None + def get(self): + if self._rlock: self._rlock.acquire() + ret = pickle.loads(self._reader.recv_bytes()) + if self._rlock: self._rlock.release() + return ret + def put(self, obj): + if self._wlock: self._wlock.acquire() + self._writer.send_bytes(pickle.dumps(obj)) + if self._wlock: self._wlock.release() + +def shuffled_indices(n): + indices = {} + for i in range(n-1, -1, -1): + j = random.randint(0, i) + if i not in indices: indices[i] = i + if j not in indices: indices[j] = j + indices[i], indices[j] = indices[j], indices[i] + yield indices[i] + del indices[i] + +def loader_process(q_in, q_out, X:Tensor): + while (_recv := q_in.get()) is not None: + idx, fn = _recv + img = Image.open(fn) + img = img.convert('RGB') if img.mode != "RGB" else img + + # eval: 76.08%, load in 0m7.366s (0m5.301s with simd) + # sudo apt-get install libjpeg-dev + # CC="cc -mavx2" pip install -U --force-reinstall pillow-simd + rescale = min(img.size) / 256 + crop_left = (img.width - 224*rescale) / 2.0 + crop_top = (img.height - 224*rescale) / 2.0 + img = img.resize((224, 224), Image.BILINEAR, box=(crop_left, crop_top, crop_left+224*rescale, crop_top+224*rescale)) + + # broken out + #img_tensor = Tensor(img.tobytes(), device='CPU') + #storage_tensor = 
X[idx].contiguous().realize().lazydata.realized + #storage_tensor._copyin(img_tensor.numpy()) + + # faster + X[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes() + + # ideal + #X[idx].assign(img.tobytes()) # NOTE: this is slow! + q_out.put(idx) + +def batch_load_resnet(batch_size=64, val=False, shuffle=True): + from extra.datasets.imagenet import get_train_files, get_val_files + files = get_val_files() if val else get_train_files() + from extra.datasets.imagenet import get_imagenet_categories + cir = get_imagenet_categories() + + BATCH_COUNT = 32 + #q_in, q_out = MyQueue(multiple_writers=False), MyQueue(multiple_readers=False) + q_in, q_out = Queue(), Queue() + + sz = (batch_size*BATCH_COUNT, 224, 224, 3) + shm = shared_memory.SharedMemory(name="resnet_X", create=True, size=prod(sz)) + # disk:shm is slower + #X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:shm:{shm.name}") + X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:/dev/shm/resnet_X") + Y = [None] * (batch_size*BATCH_COUNT) + + procs = [] + for _ in range(64): + p = Process(target=loader_process, args=(q_in, q_out, X)) + p.daemon = True + p.start() + procs.append(p) + + gen = shuffled_indices(len(files)) if shuffle else iter(range(len(files))) + def enqueue_batch(num): + for idx in range(num*batch_size, (num+1)*batch_size): + fn = files[next(gen)] + q_in.put((idx, fn)) + Y[idx] = cir[fn.split("/")[-2]] + for bn in range(BATCH_COUNT): enqueue_batch(bn) + + class Cookie: + def __init__(self, num): self.num = num + def __del__(self): + try: enqueue_batch(self.num) + except StopIteration: pass + + gotten = [0]*BATCH_COUNT + def receive_batch(): + while 1: + num = q_out.get()//batch_size + gotten[num] += 1 + if gotten[num] == batch_size: break + gotten[num] = 0 + return X[num*batch_size:(num+1)*batch_size], Y[num*batch_size:(num+1)*batch_size], Cookie(num) + + # NOTE: this is batch aligned, last ones are ignored + for _ in range(0, 
len(files)//batch_size): yield receive_batch() + + # shutdown processes + for _ in procs: q_in.put(None) + for p in procs: p.join() + shm.close() + shm.unlink() + +if __name__ == "__main__": + from extra.datasets.imagenet import get_train_files, get_val_files + VAL = getenv("VAL", 1) + files = get_val_files() if VAL else get_train_files() + with tqdm(total=len(files)) as pbar: + for x,y,c in batch_load_resnet(val=VAL): + pbar.update(x.shape[0]) diff --git a/examples/mlperf/model_eval.py b/examples/mlperf/model_eval.py index 38bf14b567..73e40cf389 100644 --- a/examples/mlperf/model_eval.py +++ b/examples/mlperf/model_eval.py @@ -1,54 +1,67 @@ import time +start = time.perf_counter() from pathlib import Path import numpy as np -from tinygrad import Tensor, GlobalCounters, dtypes +from tinygrad import Tensor, Device, dtypes, GlobalCounters from tinygrad.jit import TinyJit -from tinygrad.helpers import getenv +from tinygrad.nn.state import get_parameters, load_state_dict, safe_load +from tinygrad.helpers import getenv, Timing from examples.mlperf import helpers +def tlog(x): print(f"{x:25s} @ {time.perf_counter()-start:5.2f}s") def eval_resnet(): + Tensor.no_grad = True # Resnet50-v1.5 - from tinygrad.jit import TinyJit from extra.models.resnet import ResNet50 - mdl = ResNet50() - mdl.load_from_pretrained() + tlog("imports") + Device.DEFAULT + tlog("got devices") # NOTE: this is faster with rocm-smi running - input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1) - input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1) - def input_fixup(x): - x = x.permute([0,3,1,2]).cast(dtypes.float32) / 255.0 - x -= input_mean - x /= input_std - return x + class ResnetRunner: + def __init__(self, device=None): + self.mdl = ResNet50() + for x in get_parameters(self.mdl) if device else []: x.to_(device) + if (fn:=getenv("RESNET_MODEL", "")): load_state_dict(self.mdl, safe_load(fn)) + else: self.mdl.load_from_pretrained() + self.input_mean = Tensor([0.485, 0.456, 
0.406], device=device).reshape(1, -1, 1, 1) + self.input_std = Tensor([0.229, 0.224, 0.225], device=device).reshape(1, -1, 1, 1) + def __call__(self, x:Tensor) -> Tensor: + x = x.permute([0,3,1,2]).cast(dtypes.float32) / 255.0 + x -= self.input_mean + x /= self.input_std + return self.mdl(x).argmax(axis=1).realize() - mdlrun = lambda x: mdl(input_fixup(x)).realize() - mdljit = TinyJit(mdlrun) + GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 6))] + mdljit = [TinyJit(ResnetRunner(d)) for d in GPUS] + tlog("loaded models") # evaluation on the mlperf classes of the validation set from imagenet - from extra.datasets.imagenet import iterate - - BS = 64 + from examples.mlperf.dataloader import batch_load_resnet + iterator = batch_load_resnet(getenv("BS", 128), val=getenv("VAL", 1), shuffle=False) + def data_get(device): + x,y,cookie = next(iterator) + return x.to(device).realize(), y, cookie n,d = 0,0 + proc = [data_get(d) for d in GPUS] + tlog("loaded initial data") st = time.perf_counter() - iterator = iterate(BS) - x,ny = next(iterator) - dat = Tensor(x) - while dat is not None: - y = ny + while proc is not None: GlobalCounters.reset() - mt = time.perf_counter() - outs = mdlrun(dat) if dat.shape[0] != BS else mdljit(dat) - try: - x,ny = next(iterator) - dat = Tensor(x) - except StopIteration: - dat = None - t = outs.argmax(axis=1).numpy() + proc = [(m(x), y, c) for m,(x,y,c) in zip(mdljit, proc)] # this frees the images + run = time.perf_counter() + # load the next data here + try: next_proc = [data_get(d) for d in GPUS] + except StopIteration: next_proc = None + nd = time.perf_counter() + proc = [t.numpy() == y for t, y, _ in proc] # this realizes the models and frees the cookies + for match in proc: + n += match.sum() + d += len(match) et = time.perf_counter() - n += (t==y).sum() - d += len(t) - print(f"****** {n}/{d} {n*100.0/d:.2f}% -- {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:7.2f} ms to run model. {len(t)/(et-mt):.2f} examples/sec. 
{GlobalCounters.global_ops*1e-12/(et-mt):.2f} TFLOPS") - st = time.perf_counter() + tlog(f"****** {n:5d}/{d:5d} {n*100.0/d:.2f}% -- {(run-st)*1000:7.2f} ms to enqueue, {(et-run)*1000:7.2f} ms to realize ({(nd-run)*1000:7.2f} ms fetching). {(len(match)*len(proc))/(et-st):8.2f} examples/sec. {GlobalCounters.global_ops*1e-12/(et-st):5.2f} TFLOPS") + st = et + proc, next_proc = next_proc, None + tlog("done") def eval_unet3d(): # UNet3D @@ -238,4 +251,4 @@ if __name__ == "__main__": nm = f"eval_{m}" if nm in globals(): print(f"eval {m}") - globals()[nm]() \ No newline at end of file + globals()[nm]() diff --git a/extra/datasets/.gitignore b/extra/datasets/.gitignore new file mode 100644 index 0000000000..7b843be36a --- /dev/null +++ b/extra/datasets/.gitignore @@ -0,0 +1,2 @@ +imagenet +imagenet_bak diff --git a/extra/datasets/imagenet.py b/extra/datasets/imagenet.py index dde32a5e40..1770bc12dd 100644 --- a/extra/datasets/imagenet.py +++ b/extra/datasets/imagenet.py @@ -1,27 +1,25 @@ # for imagenet download prepare.sh and run it -import glob, random -import json +import glob, random, json import numpy as np from PIL import Image import functools, pathlib +from tinygrad.helpers import DEBUG, diskcache BASEDIR = pathlib.Path(__file__).parent / "imagenet" -ci = json.load(open(BASEDIR / "imagenet_class_index.json")) -cir = {v[0]: int(k) for k,v in ci.items()} @functools.lru_cache(None) -def get_train_files(): - train_files = open(BASEDIR / "train_files").read().strip().split("\n") - return [(BASEDIR / "train" / x) for x in train_files] +def get_imagenet_categories(): + ci = json.load(open(BASEDIR / "imagenet_class_index.json")) + return {v[0]: int(k) for k,v in ci.items()} + +@diskcache +def get_train_files(): return glob.glob(str(BASEDIR / "train/*/*")) @functools.lru_cache(None) -def get_val_files(): - val_files = glob.glob(str(BASEDIR / "val/*/*")) - return val_files +def get_val_files(): return glob.glob(str(BASEDIR / "val/*/*")) -#rrc = 
transforms.RandomResizedCrop(224) -import torchvision.transforms.functional as F def image_load(fn): + import torchvision.transforms.functional as F img = Image.open(fn).convert('RGB') img = F.resize(img, 256, Image.BILINEAR) img = F.center_crop(img, 224) @@ -29,8 +27,10 @@ def image_load(fn): return ret def iterate(bs=32, val=True, shuffle=True): + cir = get_imagenet_categories() files = get_val_files() if val else get_train_files() order = list(range(0, len(files))) + if DEBUG >= 1: print(f"imagenet size {len(order)}") if shuffle: random.shuffle(order) from multiprocessing import Pool p = Pool(16) @@ -40,6 +40,7 @@ def iterate(bs=32, val=True, shuffle=True): yield (np.array(X), np.array(Y)) def fetch_batch(bs, val=False): + cir = get_imagenet_categories() files = get_val_files() if val else get_train_files() samp = np.random.randint(0, len(files), size=(bs)) files = [files[i] for i in samp] diff --git a/test/unit/test_disk_cache.py b/test/unit/test_disk_cache.py index 29b9b5215c..5e55de3b6b 100644 --- a/test/unit/test_disk_cache.py +++ b/test/unit/test_disk_cache.py @@ -1,6 +1,6 @@ import unittest import pickle -from tinygrad.helpers import diskcache_get, diskcache_put +from tinygrad.helpers import diskcache_get, diskcache_put, diskcache def remote_get(table,q,k): q.put(diskcache_get(table, k)) def remote_put(table,k,v): diskcache_put(table, k, v) @@ -50,6 +50,20 @@ class DiskCache(unittest.TestCase): self.assertEqual(diskcache_get(table, 4), 5) self.assertEqual(diskcache_get(table, "4"), 5) + def test_decorator(self): + calls = 0 + @diskcache + def hello(x): + nonlocal calls + calls += 1 + return "world"+x + self.assertEqual(hello("bob"), "worldbob") + self.assertEqual(hello("billy"), "worldbilly") + kcalls = calls + self.assertEqual(hello("bob"), "worldbob") + self.assertEqual(hello("billy"), "worldbilly") + self.assertEqual(kcalls, calls) + def test_dict_key(self): table = "test_dict_key" fancy_key = {"hello": "world", "goodbye": 7, "good": True, "pkl": 
pickle.dumps("cat")} diff --git a/tinygrad/device.py b/tinygrad/device.py index 230dd32d35..8352110d06 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -79,9 +79,10 @@ class Buffer: if isinstance(self.dtype, ImageDType): self.allocator.free(self._buf, self.dtype) else: self.allocator.free(self._buf, self.size * self.dtype.itemsize) def __repr__(self): return f"" - def as_buffer(self, allow_zero_copy=False) -> memoryview: + def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview: # zero copy with as_buffer (disabled by default due to use after free) - if allow_zero_copy and hasattr(self.allocator, 'as_buffer'): return self.allocator.as_buffer(self._buf) + if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer'): return self.allocator.as_buffer(self._buf) + assert not force_zero_copy, "force zero copy was passed, but copy is required" return self.copyout(memoryview(bytearray(self.size*self.dtype.itemsize))) def copyin(self, mv:memoryview): mv = flat_mv(mv) diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 27f0ddf210..d580466aff 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -153,6 +153,13 @@ def diskcache_put(table:str, key:Union[Dict, str, int], val:Any): cur.close() return val +def diskcache(func): + def wrapper(*args, **kwargs) -> bytes: + table, key = f"cache_{func.__name__}", hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest() + if (ret:=diskcache_get(table, key)) is not None: return ret + return diskcache_put(table, key, func(*args, **kwargs)) + return wrapper + # *** http support *** def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path: diff --git a/tinygrad/runtime/graph/cuda.py b/tinygrad/runtime/graph/cuda.py index 0af6af0179..067201d6be 100644 --- a/tinygrad/runtime/graph/cuda.py +++ b/tinygrad/runtime/graph/cuda.py @@ -1,7 +1,7 @@ import ctypes from typing import Any, Optional, Tuple, Dict, List, cast
import gpuctypes.cuda as cuda -from tinygrad.helpers import init_c_var, encode_args_cuda_style +from tinygrad.helpers import init_c_var, encode_args_cuda_style, all_same from tinygrad.device import CompiledASTRunner, update_stats, Buffer from tinygrad.runtime.ops_cuda import check, cu_time_execution from tinygrad.shape.symbolic import Variable @@ -9,7 +9,9 @@ from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_ class CUDAGraph: def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]): - if not all(isinstance(ji.prg, CompiledASTRunner) for ji in jit_cache): raise GraphException + devices = [ji.prg.clprg.device if isinstance(ji.prg, CompiledASTRunner) else None for ji in jit_cache] + if len(devices) == 0 or not all_same(devices) or devices[0] is None: raise GraphException + self.device = devices[0] self.jit_cache = jit_cache self.input_replace = get_input_replace(jit_cache, input_rawbuffers) diff --git a/tinygrad/runtime/graph/hip.py b/tinygrad/runtime/graph/hip.py index 7dabad1985..7cc25a5bf4 100644 --- a/tinygrad/runtime/graph/hip.py +++ b/tinygrad/runtime/graph/hip.py @@ -1,7 +1,9 @@ import ctypes -from typing import Tuple +from typing import Tuple, List, Dict, Optional import gpuctypes.hip as hip from tinygrad.helpers import init_c_var +from tinygrad.device import Buffer +from tinygrad.shape.symbolic import Variable from tinygrad.runtime.ops_hip import check, hip_time_execution from tinygrad.runtime.graph.cuda import CUDAGraph @@ -9,7 +11,9 @@ class HIPGraph(CUDAGraph): def __del__(self): check(hip.hipGraphDestroy(self.graph)) check(hip.hipGraphExecDestroy(self.instance)) - + def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]: + check(hip.hipSetDevice(self.device)) + return super().__call__(input_rawbuffers, var_vals, wait, jit) def encode_args_info(self): return (hip.hipDeviceptr_t, (1,2,3)) def 
graph_create(self): return init_c_var(hip.hipGraph_t(), lambda x: check(hip.hipGraphCreate(ctypes.byref(x), 0))) def graph_instantiate(self, graph): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 22a4cec206..288af7441d 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -165,8 +165,7 @@ class Tensor: def to_(self, device:Optional[str]): if device is None or device == self.device: return if self.grad: self.grad = self.grad.to_(device) - _ret = Tensor(self.lazydata, device) - self.lazydata = _ret.lazydata + self.lazydata = Tensor(self.lazydata, device).lazydata def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None) -> Tensor: assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"