remove dtype from FlopCounter (#4075)

the annoying thing to remove all FlopCounter is that for device that does not support local, matmul index alu is huge.
we can remove the dtype first.

sneak in updating `ruff` command to `ruff check`
This commit is contained in:
chenyu
2024-04-04 21:23:28 -04:00
committed by GitHub
parent 3de855ea50
commit 9e0ebf8979
3 changed files with 10 additions and 11 deletions

View File

@@ -84,7 +84,7 @@ jobs:
- name: Lint with ruff
run: |
pip3 install --upgrade --force-reinstall ruff
python3 -m ruff . --preview
python3 -m ruff check . --preview
- name: Lint tinygrad with pylint
run: python -m pylint tinygrad/
- name: Run mypy

View File

@@ -9,7 +9,7 @@ repos:
pass_filenames: false
- id: ruff
name: ruff
entry: ruff . --preview
entry: ruff check . --preview
language: system
always_run: true
pass_filenames: false

View File

@@ -78,7 +78,6 @@ class LazyOp:
@dataclass
class FlopCounter:
shape: Tuple[int, ...]
dtype: DType
flops: sint
mem: Dict[int, int]
@property
@@ -88,14 +87,14 @@ class FlopCounter:
return ret
InterpretedFlopCounter: Dict[Op, Callable] = {
BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {arg.idx: arg.dtype.itemsize*arg.st.real_size()}),
BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {}),
BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, arg.dtype, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize*arg.st.real_size()}), # noqa: E501
UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem), # cast uses no flops
**{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op is not UnaryOps.CAST}, # noqa: E501
**{op:lambda self,y,op=op: FlopCounter(self.shape, dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else self.dtype, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
**{op:lambda self,axis: FlopCounter(tuple(1 if i in axis else s for i,s in enumerate(self.shape)), self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, # noqa: E501
TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, y.dtype, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501
BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, 0, {arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, 0, {}),
BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # cast uses no flops
**{op:lambda self: FlopCounter(self.shape, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op is not UnaryOps.CAST},
**{op:lambda self,y: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
**{op:lambda self,axis: FlopCounter(tuple(1 if i in axis else s for i,s in enumerate(self.shape)), self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, # noqa: E501
TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501
@functools.lru_cache(None)
def get_lazyop_info(ast:LazyOp) -> FlopCounter: