From c8fbdeb48ef98dc259999ac0463f57bc9ba4ff13 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Sun, 25 Jun 2023 15:22:56 -0700
Subject: [PATCH] test speed llama (#1046)

* test speed llama

* oops, put it back

* uses the real device codegen

* just do it on the mac

* pp

* is faster?

* Revert "is faster?"

This reverts commit 42db542010906dd62376c0e419416978d03d3d62.

* disable docker again for less load on CI
---
 .github/workflows/test.yml                 |  4 +-
 test/external/external_test_speed_llama.py | 49 ++++++++++++++++++++++
 test/test_net_speed.py                     |  4 +-
 test/unit/test_example.py                  |  2 +-
 tinygrad/lazy.py                           |  2 +-
 tinygrad/runtime/ops_fake.py               | 17 ++++++
 tinygrad/tensor.py                         |  4 +-
 7 files changed, 75 insertions(+), 7 deletions(-)
 create mode 100644 test/external/external_test_speed_llama.py
 create mode 100644 tinygrad/runtime/ops_fake.py

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 70571ec4ef..4e22612cde 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -207,16 +207,18 @@ jobs:
         python-version: 3.11
     - name: Install Dependencies
       run: pip install -e '.[metal,testing]'
+    - name: Test LLaMA compile speed
+      run: PYTHONPATH="." METAL=1 python3 test/external/external_test_speed_llama.py
     #- name: Run dtype test
     #  run: DEBUG=4 METAL=1 python -m pytest test/test_dtype.py
     - name: Run ops test
       run: DEBUG=2 METAL=1 python -m pytest test/test_ops.py

   # dtype test has issues on test_half_to_int8
-  # disabled, this test is flaky
   testdocker:
     name: Docker Test
     runs-on: ubuntu-latest
+    if: ${{ false }}

     steps:
     - name: Checkout Code
diff --git a/test/external/external_test_speed_llama.py b/test/external/external_test_speed_llama.py
new file mode 100644
index 0000000000..8d1d308896
--- /dev/null
+++ b/test/external/external_test_speed_llama.py
@@ -0,0 +1,49 @@
+# NOTE: this only tests the speed of the LLaMA codegen, it doesn't actually run the net
+import unittest, time
+from examples.llama import Transformer, args_7B
+from test.test_net_speed import start_profile, stop_profile
+from tinygrad.tensor import Tensor
+from tinygrad.helpers import getenv
+from tinygrad.lazy import Device
+from tinygrad.state import get_state_dict
+from tinygrad.ops import Compiled
+
+class TestLLaMASpeed(unittest.TestCase):
+  @unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "only test for compiled backends")
+  def test_llama_compile(self):
+    # TODO: with default device
+    old_default = Device.DEFAULT
+    Device.DEFAULT = "FAKE"
+
+    # use the codegen from the real device
+    Device['fake'].codegen = Device[old_default].codegen
+    print("using", Device['fake'].codegen)
+
+    print("testing llama python run time")
+    model = Transformer(**args_7B)
+    print("built model")
+    # assign fake tensors to the values
+    for v in get_state_dict(model).values(): v.assign(Tensor.empty(*v.shape, dtype=v.dtype))
+    print("assigned empty tensors, doing warmup")
+
+    def run_llama(st, empty_method_cache=True):
+      #print(f"clearing {len(Device['fake'].method_cache)} from method cache")
+      if empty_method_cache: Device['fake'].method_cache.clear()
+      tms = [time.perf_counter()]
+      for i in range(5):
+        model(Tensor([[2]]), i).realize()
+        tms.append(time.perf_counter())
+      print(f"{st:15s} runtime in ms:", ', '.join("%.2f"%((tms[i+1]-tms[i])*1000) for i in range(len(tms)-1)))
+
+    run_llama("compile")
+    run_llama("methodcache", False)
+
+    pr = start_profile()
+    run_llama("profile")
+    stop_profile(pr, sort='time', frac=0.1)
+
+    # reset device
+    Device.DEFAULT = old_default
+
+if __name__ == '__main__':
+  unittest.main()
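
The run_llama helper above times five successive forward calls and prints the per-call wall-clock time in milliseconds, optionally clearing the method cache first so the first pass pays the full codegen cost. Below is a minimal standalone sketch of that timing pattern; time_calls and the toy workload are illustrative stand-ins, not part of this patch.

# minimal sketch of the per-call timing pattern used by run_llama above
# (the workload here is a stand-in for the model call, not tinygrad code)
import time

def time_calls(label, fn, n=5):
    # record a timestamp before the first call and after every call,
    # then print the per-call deltas in milliseconds
    tms = [time.perf_counter()]
    for _ in range(n):
        fn()
        tms.append(time.perf_counter())
    print(f"{label:15s} runtime in ms:",
          ', '.join("%.2f" % ((tms[i+1] - tms[i]) * 1000) for i in range(len(tms) - 1)))

if __name__ == "__main__":
    workload = lambda: sum(i * i for i in range(100_000))  # stand-in for model(...).realize()
    time_calls("first run", workload)
    time_calls("second run", workload)
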
diff --git a/test/test_net_speed.py b/test/test_net_speed.py
index df72d3dc8f..e382783505 100644
--- a/test/test_net_speed.py
+++ b/test/test_net_speed.py
@@ -12,12 +12,12 @@ def start_profile():
   pr.enable()
   return pr

-def stop_profile(pr, sort='cumtime'):
+def stop_profile(pr, sort='cumtime', frac=0.2):
   pr.disable()
   ps = pstats.Stats(pr)
   ps.strip_dirs()
   ps.sort_stats(sort)
-  ps.print_stats(0.2)
+  ps.print_stats(frac)

 class TestConvSpeed(unittest.TestCase):
diff --git a/test/unit/test_example.py b/test/unit/test_example.py
index 4101bb2c2c..db5e7a7aea 100644
--- a/test/unit/test_example.py
+++ b/test/unit/test_example.py
@@ -8,7 +8,7 @@ def multidevice_test(fxn):
   exclude_devices = getenv("EXCLUDE_DEVICES", "").split(",")
   def ret(self):
     for device in Device._buffers:
-      if device == "DISK": continue
+      if device in ["DISK", "FAKE"]: continue
       print(device)
       if device in exclude_devices:
         print(f"WARNING: {device} test is excluded")
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 19be4b4c9d..8461468c0c 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -314,7 +314,7 @@ class _Device:
   def __init__(self) -> None:
     self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
     self.DEFAULT: str = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, None) or self._default_device()
-  def canonicalize(self, device:str) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "")
+  def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT
   def __getitem__(self, x:str) -> Union[Interpreted, Compiled]: return self._get_device(x.split(":")[0].upper())
   @functools.lru_cache(maxsize=None)  # this class is a singleton, pylint: disable=method-cache-max-size-none
   def _get_device(self, x:str) -> Union[Interpreted, Compiled]: return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0]
diff --git a/tinygrad/runtime/ops_fake.py b/tinygrad/runtime/ops_fake.py
new file mode 100644
index 0000000000..d5cbbd348b
--- /dev/null
+++ b/tinygrad/runtime/ops_fake.py
@@ -0,0 +1,17 @@
+# used for compilation only speed tests
+import numpy as np
+from tinygrad.helpers import dtypes, prod
+from tinygrad.ops import Compiled
+from tinygrad.runtime.lib import RawBuffer
+
+class RawFakeBuffer(RawBuffer):
+  @classmethod
+  def fromCPU(cls, x:np.ndarray, **kwargs): return cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs)
+  def toCPU(self): return np.empty(self.size, dtype=self.dtype.np)
+
+class FakeProgram:
+  def __init__(self, name:str, prg:str): pass
+  def __call__(self, global_size, local_size, *args, wait=False): pass
+
+# NOTE: you have to set a codegen to use this
+FakeBuffer = Compiled(RawFakeBuffer, None, FakeProgram)
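
ops_fake.py above registers a Compiled backend whose buffers never hold real data and whose program body is a no-op, so realizing a tensor on the FAKE device exercises code generation and the method cache without launching any kernels. Below is a rough standalone sketch of that compile-only idea; NoOpProgram, METHOD_CACHE and compile_kernel are illustrative names, not tinygrad's API.

# standalone sketch of a "compile-only" backend: codegen runs for real, execution is a no-op
class NoOpProgram:
    def __init__(self, name, src):
        self.name, self.src = name, src  # keep the generated source around for inspection
    def __call__(self, *bufs, **kwargs):
        pass  # a real backend would launch the kernel here

METHOD_CACHE = {}

def compile_kernel(name, codegen):
    # mimic a method cache: pay the codegen cost only the first time a kernel is seen
    if name not in METHOD_CACHE:
        METHOD_CACHE[name] = NoOpProgram(name, codegen())
    return METHOD_CACHE[name]

if __name__ == "__main__":
    prg = compile_kernel("add_kernel", lambda: "/* generated source */")  # codegen happens here
    prg()                 # "runs" instantly, nothing is executed
    METHOD_CACHE.clear()  # analogous to what the test does between its "compile" and "methodcache" passes
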
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index a95c70e5c4..31da98860c 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -39,7 +39,7 @@ class Tensor:
   no_grad: ClassVar[bool] = False
   default_type: ClassVar[DType] = dtypes.float32

-  def __init__(self, data:Union[int, float, list, tuple, LazyBuffer, np.ndarray], device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
+  def __init__(self, data:Union[int, float, list, tuple, LazyBuffer, np.ndarray], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
     assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
     device = Device.canonicalize(device)
     # tensors have gradients, buffers do not
@@ -124,7 +124,7 @@ class Tensor:
   # ***** creation llop entrypoint *****

   @staticmethod
-  def _loadop(op, sz, device=Device.DEFAULT, dtype:Optional[DType]=None, arg=None, **kwargs):
+  def _loadop(op, sz, device:Optional[str]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
     return Tensor(LazyBuffer.loadop(op, [sz], Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs)

   @staticmethod
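
The tensor.py changes above lean on the canonicalize change in lazy.py: device=None now resolves to Device.DEFAULT, otherwise the part of the string before ':' is upper-cased and a trailing ':0' is dropped. Below is a standalone sketch of that normalization rule; canonicalize_device and DEFAULT_DEVICE are stand-in names, and "METAL" is only an assumed default for the example.

# standalone sketch of the device-string normalization Device.canonicalize performs
from typing import Optional

DEFAULT_DEVICE = "METAL"  # assumed default, only for this example

def canonicalize_device(device: Optional[str]) -> str:
    # None means "use the default device"; otherwise upper-case the backend name,
    # keep any ":N" suffix, and treat ":0" the same as no suffix at all
    if device is None:
        return DEFAULT_DEVICE
    backend, _, idx = device.partition(":")
    return (backend.upper() + ((":" + idx) if idx else "")).replace(":0", "")

if __name__ == "__main__":
    assert canonicalize_device(None) == "METAL"    # Tensor(..., device=None) now resolves to the default
    assert canonicalize_device("gpu") == "GPU"
    assert canonicalize_device("gpu:1") == "GPU:1"
    assert canonicalize_device("GPU:0") == "GPU"   # ":0" is the same as the bare device
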