From cec0a7bc3783955a5b9a586fa5abc0aa9c97fcab Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 15 Jan 2024 16:49:38 -0800 Subject: [PATCH] use shard api to eval resnet fast (#3136) * use shard api to eval resnet fast * to supports shard * test to in multitensor --- examples/mlperf/model_eval.py | 27 +++++++++++++-------------- test/test_multitensor.py | 11 +++++++++-- tinygrad/tensor.py | 12 +++++++----- 3 files changed, 29 insertions(+), 21 deletions(-) diff --git a/examples/mlperf/model_eval.py b/examples/mlperf/model_eval.py index 73e40cf389..fe2bc3ad3d 100644 --- a/examples/mlperf/model_eval.py +++ b/examples/mlperf/model_eval.py @@ -14,7 +14,8 @@ def eval_resnet(): # Resnet50-v1.5 from extra.models.resnet import ResNet50 tlog("imports") - Device.DEFAULT + GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 6))] + for x in GPUS: Device[x] tlog("got devices") # NOTE: this is faster with rocm-smi running class ResnetRunner: @@ -31,34 +32,32 @@ def eval_resnet(): x /= self.input_std return self.mdl(x).argmax(axis=1).realize() - GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 6))] - mdljit = [TinyJit(ResnetRunner(d)) for d in GPUS] + mdl = ResnetRunner(GPUS) tlog("loaded models") # evaluation on the mlperf classes of the validation set from imagenet from examples.mlperf.dataloader import batch_load_resnet - iterator = batch_load_resnet(getenv("BS", 128), val=getenv("VAL", 1), shuffle=False) - def data_get(device): + iterator = batch_load_resnet(getenv("BS", 128*6), val=getenv("VAL", 1), shuffle=False) + def data_get(): x,y,cookie = next(iterator) - return x.to(device).realize(), y, cookie + return x.shard(GPUS, axis=0).realize(), y, cookie n,d = 0,0 - proc = [data_get(d) for d in GPUS] + proc = data_get() tlog("loaded initial data") st = time.perf_counter() while proc is not None: GlobalCounters.reset() - proc = [(m(x), y, c) for m,(x,y,c) in zip(mdljit, proc)] # this frees the images + proc = (mdl(proc[0]), proc[1], proc[2]) # this frees the images run = time.perf_counter() # load the next data here - try: next_proc = [data_get(d) for d in GPUS] + try: next_proc = data_get() except StopIteration: next_proc = None nd = time.perf_counter() - proc = [t.numpy() == y for t, y, _ in proc] # this realizes the models and frees the cookies - for match in proc: - n += match.sum() - d += len(match) + proc = proc[0].numpy() == proc[1] # this realizes the models and frees the cookies + n += proc.sum() + d += proc.size et = time.perf_counter() - tlog(f"****** {n:5d}/{d:5d} {n*100.0/d:.2f}% -- {(run-st)*1000:7.2f} ms to enqueue, {(et-run)*1000:7.2f} ms to realize ({(nd-run)*1000:7.2f} ms fetching). {(len(match)*len(proc))/(et-st):8.2f} examples/sec. {GlobalCounters.global_ops*1e-12/(et-st):5.2f} TFLOPS") + tlog(f"****** {n:5d}/{d:5d} {n*100.0/d:.2f}% -- {(run-st)*1000:7.2f} ms to enqueue, {(et-run)*1000:7.2f} ms to realize ({(nd-run)*1000:7.2f} ms fetching). {(len(proc))/(et-st):8.2f} examples/sec. {GlobalCounters.global_ops*1e-12/(et-st):5.2f} TFLOPS") st = et proc, next_proc = next_proc, None tlog("done") diff --git a/test/test_multitensor.py b/test/test_multitensor.py index e5d7ac57b3..2568d05c1e 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -2,8 +2,6 @@ import unittest from tinygrad import Tensor, Device, nn, GlobalCounters from tinygrad.helpers import CI from tinygrad.nn.state import get_parameters -from extra.lr_scheduler import OneCycleLR -from extra.models.llama import RMSNorm, Attention import numpy as np d_zero = f"{Device.DEFAULT}:0" @@ -18,6 +16,12 @@ N = 128 @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI") class TestMultiTensor(unittest.TestCase): + def test_to(self): + X = Tensor.ones(256).contiguous().realize() + X.to_((d0, d1)) + for lb in X.lazydata.lbs: + assert lb.shape == (256,) + def test_shard(self): X = Tensor.ones(256).contiguous().realize() X.shard_((d0, d1), 0) @@ -135,6 +139,7 @@ class TestMultiTensor(unittest.TestCase): optim.step() def test_lr_scheduler_OneCycleLR(self): + from extra.lr_scheduler import OneCycleLR conv = nn.Conv2d(3, 16, 3) for p in get_parameters(conv): p.shard_((d0, d1)) optim = nn.optim.SGD(get_parameters(conv)) @@ -156,6 +161,7 @@ class TestMultiTensor(unittest.TestCase): np.testing.assert_allclose(z.numpy(), z_shard.numpy(), atol=1e-6, rtol=1e-6) def test_rmsnorm(self): + from extra.models.llama import RMSNorm B, T, embed_size = 4, 10, 20 layer_norm = RMSNorm(embed_size) @@ -180,6 +186,7 @@ class TestMultiTensor(unittest.TestCase): @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault") @unittest.skipIf(Device.DEFAULT == "GPU", "GPU requires cl_khr_fp16") def _test_llama_attention(self, device): + from extra.models.llama import Attention bs, seq_len, dim, n_heads, n_kv_heads, max_context = 1, 1, 128, 4, 4, 32 freqs_cis = Tensor.rand(1, seq_len, 1, (dim//n_heads)//2, 2).half() mask = None diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 288af7441d..4e981bc249 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -156,16 +156,18 @@ class Tensor: assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}" return np.frombuffer(self._data(), dtype=self.dtype.np).reshape(self.shape) - def to(self, device:Optional[str]) -> Tensor: + def to(self, device:Optional[Union[str, Tuple[str, ...]]]) -> Tensor: if device is None or device == self.device: return self + if not isinstance(device, str): return self.shard(device) ret = Tensor(self.lazydata, device) if self.grad: ret.grad = self.grad.to(device) return ret - def to_(self, device:Optional[str]): - if device is None or device == self.device: return - if self.grad: self.grad = self.grad.to_(device) - self.lazydata = Tensor(self.lazydata, device).lazydata + def to_(self, device:Optional[Union[str, Tuple[str, ...]]]): + real = self.to(device) + # TODO: is this assign? + if self.grad is not None and real.grad is not None: self.grad.lazydata = real.grad.lazydata + self.lazydata = real.lazydata def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None) -> Tensor: assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"