From 9e0ebf8979090b35c84e17f19f69a4450d284e03 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 4 Apr 2024 21:23:28 -0400 Subject: [PATCH] remove dtype from FlopCounter (#4075) the annoying thing to remove all FlopCounter is that for device that does not support local, matmul index alu is huge. we can remove the dtype first. sneak in updating `ruff` command to `ruff check` --- .github/workflows/test.yml | 2 +- .pre-commit-config.yaml | 2 +- tinygrad/ops.py | 17 ++++++++--------- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9d2ec38ae1..8ffd6fe31d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -84,7 +84,7 @@ jobs: - name: Lint with ruff run: | pip3 install --upgrade --force-reinstall ruff - python3 -m ruff . --preview + python3 -m ruff check . --preview - name: Lint tinygrad with pylint run: python -m pylint tinygrad/ - name: Run mypy diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bff44f59db..13a0324240 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ repos: pass_filenames: false - id: ruff name: ruff - entry: ruff . --preview + entry: ruff check . --preview language: system always_run: true pass_filenames: false diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 18ecbaa781..be5bf1d7bc 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -78,7 +78,6 @@ class LazyOp: @dataclass class FlopCounter: shape: Tuple[int, ...] - dtype: DType flops: sint mem: Dict[int, int] @property @@ -88,14 +87,14 @@ class FlopCounter: return ret InterpretedFlopCounter: Dict[Op, Callable] = { - BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {arg.idx: arg.dtype.itemsize*arg.st.real_size()}), - BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {}), - BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, arg.dtype, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize*arg.st.real_size()}), # noqa: E501 - UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem), # cast uses no flops - **{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op is not UnaryOps.CAST}, # noqa: E501 - **{op:lambda self,y,op=op: FlopCounter(self.shape, dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else self.dtype, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501 - **{op:lambda self,axis: FlopCounter(tuple(1 if i in axis else s for i,s in enumerate(self.shape)), self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, # noqa: E501 - TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, y.dtype, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501 + BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, 0, {arg.idx: arg.dtype.itemsize * arg.st.real_size()}), + BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, 0, {}), + BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize * arg.st.real_size()}), + UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # cast uses no flops + **{op:lambda self: FlopCounter(self.shape, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op is not UnaryOps.CAST}, + **{op:lambda self,y: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501 + **{op:lambda self,axis: FlopCounter(tuple(1 if i in axis else s for i,s in enumerate(self.shape)), self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, # noqa: E501 + TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501 @functools.lru_cache(None) def get_lazyop_info(ast:LazyOp) -> FlopCounter: