From 4f8f0ac1393db379cbab6170abfd7ad3ca4b2f48 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 23 Nov 2023 09:01:50 -0800 Subject: [PATCH] minor cleanups, remove dead files (#2398) * minor cleanups, remove dead files * s.name * use disk * pytest passes on mac --- compile.sh | 8 -------- extra/utils.py | 2 +- rmso.sh | 3 --- test/extra/test_utils.py | 4 +--- test/models/test_train.py | 2 +- test/models/test_whisper.py | 3 ++- test/test_copy_speed.py | 12 ++++++++++-- tinygrad/tensor.py | 5 ++--- 8 files changed, 17 insertions(+), 22 deletions(-) delete mode 100755 compile.sh delete mode 100755 rmso.sh diff --git a/compile.sh b/compile.sh deleted file mode 100755 index 14b9f86e40..0000000000 --- a/compile.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -# note: if we compile tinygrad/nn/__init__.py __dict__ no longer works, and optimizers will silently fail -mypyc --check-untyped-defs --explicit-package-bases --warn-unreachable tinygrad/shape/shapetracker.py tinygrad/shape/symbolic.py \ - tinygrad/helpers.py tinygrad/mlops.py tinygrad/tensor.py tinygrad/graph.py \ - #tinygrad/codegen/gpu.py tinygrad/runtime/ops_metal.py - #tinygrad/codegen/ast.py - #tinygrad/nn/__init__.py - #tinygrad/ops.py tinygrad/runtime/ops_metal.py tinygrad/runtime/ops_gpu.py tinygrad/runtime/ops_cpu.py tinygrad/lazy.py diff --git a/extra/utils.py b/extra/utils.py index b261831136..db2163892a 100644 --- a/extra/utils.py +++ b/extra/utils.py @@ -37,7 +37,7 @@ def fetch_as_file(url): def download_file(url, fp, skip_if_exists=True): if skip_if_exists and Path(fp).is_file() and Path(fp).stat().st_size > 0: return - r = requests.get(url, stream=True) + r = requests.get(url, stream=True, timeout=10) assert r.status_code == 200 progress_bar = tqdm(total=int(r.headers.get('content-length', 0)), unit='B', unit_scale=True, desc=url) (path := Path(fp).parent).mkdir(parents=True, exist_ok=True) diff --git a/rmso.sh b/rmso.sh deleted file mode 100755 index 30dc7e621a..0000000000 --- a/rmso.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash -rm tinygrad/*.so tinygrad/codegen/*.so tinygrad/shape/*.so tinygrad/nn/*.so tinygrad/runtime/*.so *.so - diff --git a/test/extra/test_utils.py b/test/extra/test_utils.py index 2c0b831839..4b47c01269 100644 --- a/test/extra/test_utils.py +++ b/test/extra/test_utils.py @@ -14,9 +14,7 @@ from PIL import Image @unittest.skipIf(CI, "no internet tests in CI") class TestFetch(unittest.TestCase): def test_fetch_bad_http(self): - self.assertRaises(AssertionError, fetch, 'http://httpstat.us/500') - self.assertRaises(AssertionError, fetch, 'http://httpstat.us/404') - self.assertRaises(AssertionError, fetch, 'http://httpstat.us/400') + self.assertRaises(AssertionError, fetch, 'http://www.google.com/404') def test_fetch_small(self): assert(len(fetch('https://google.com'))>0) diff --git a/test/models/test_train.py b/test/models/test_train.py index 22e51c931b..7f4b9161f8 100644 --- a/test/models/test_train.py +++ b/test/models/test_train.py @@ -49,7 +49,7 @@ class TestTrain(unittest.TestCase): train_one_step(model,X,Y) check_gc() - @unittest.skipIf(Device.DEFAULT == "WEBGPU", "too many buffers for webgpu") + @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal") def test_vit(self): model = ViT() X = np.zeros((BS,3,224,224), dtype=np.float32) diff --git a/test/models/test_whisper.py b/test/models/test_whisper.py index 1f8acaaa44..4b93e041ae 100644 --- a/test/models/test_whisper.py +++ b/test/models/test_whisper.py @@ -62,4 +62,5 @@ class TestWhisper(unittest.TestCase): with self.assertRaises(Exception): transcribe_waveform(self.model, self.enc, waveforms) - +if __name__ == '__main__': + unittest.main() diff --git a/test/test_copy_speed.py b/test/test_copy_speed.py index dc4b532abc..4f78c1cb04 100644 --- a/test/test_copy_speed.py +++ b/test/test_copy_speed.py @@ -2,6 +2,7 @@ import unittest from tinygrad import Tensor from tinygrad.ops import Device from tinygrad.helpers import Timing, CI +import multiprocessing.shared_memory as shared_memory N = 4096 if CI else 16384 class TestCopySpeed(unittest.TestCase): @@ -9,13 +10,18 @@ class TestCopySpeed(unittest.TestCase): def setUpClass(cls): Device[Device.DEFAULT].synchronize() def testCopySHMtoDefault(self): - t = Tensor.empty(N, N, device="disk:/dev/shm/test_X").realize() - #t = Tensor.empty(N, N, device="disk:shm:test_X").realize() + s = shared_memory.SharedMemory(name="test_X", create=True, size=N*N*4) + s.close() + if CI: + t = Tensor.empty(N, N, device="disk:/dev/shm/test_X").realize() + else: + t = Tensor.empty(N, N, device="disk:shm:test_X").realize() for _ in range(3): with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"): with Timing("queue: "): t.to(Device.DEFAULT).realize() Device[Device.DEFAULT].synchronize() + s.unlink() def testCopyCPUtoDefault(self): t = Tensor.rand(N, N, device="cpu").realize() @@ -45,6 +51,8 @@ class TestCopySpeed(unittest.TestCase): @unittest.skipIf(CI, "CI doesn't have 6 GPUs") def testCopyCPUto6GPUs(self): + from tinygrad.runtime.ops_gpu import CL + if len(CL.devices) != 6: raise unittest.SkipTest("computer doesn't have 6 GPUs") t = Tensor.rand(N, N, device="cpu").realize() print(f"buffer: {t.nbytes()*1e-9:.2f} GB") for _ in range(3): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index dcd5488fdd..66bc41068d 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -72,7 +72,7 @@ class Tensor: data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item()) else: data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data) - else: raise RuntimeError(f"can't create Tensor from {data}") + else: raise RuntimeError(f"can't create Tensor from {data} with type {type(data)}") # data is a LazyBuffer, but it might be on the wrong device self.lazydata = data if data.device == device else data.copy_to_device(device) @@ -665,7 +665,7 @@ class Tensor: return (x, y) def _to_float(self, x:Union[Tensor, float]): - return x.lazydata.op.arg if isinstance(x, Tensor) and not x.lazydata.realized and x.lazydata.op.op == LoadOps.CONST and not x.requires_grad \ + return x.lazydata.base.op.arg if isinstance(x, Tensor) and x.lazydata.is_unrealized_const() and not x.requires_grad \ and x.lazydata.st.contiguous and self._broadcasted(x)[0].shape == self.shape else x def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: @@ -715,7 +715,6 @@ class Tensor: # ***** binary op wrappers (18 wasted lines to make the typechecker happy) ***** - # NOTE: __pow__ and friends are broken in mypyc with the ** operator def __add__(self, x) -> Tensor: return self.add(x) def __sub__(self, x) -> Tensor: return self.sub(x) def __mul__(self, x) -> Tensor: return self.mul(x)