mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
remove dtype from FlopCounter (#4075)
The annoying part of removing FlopCounter entirely is that, for devices that do not support local, the matmul index ALU count is huge, so we remove the dtype first. This commit also sneaks in updating the `ruff` command to `ruff check`.
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -84,7 +84,7 @@ jobs:
|
||||
- name: Lint with ruff
|
||||
run: |
|
||||
pip3 install --upgrade --force-reinstall ruff
|
||||
python3 -m ruff . --preview
|
||||
python3 -m ruff check . --preview
|
||||
- name: Lint tinygrad with pylint
|
||||
run: python -m pylint tinygrad/
|
||||
- name: Run mypy
|
||||
|
||||
@@ -9,7 +9,7 @@ repos:
|
||||
pass_filenames: false
|
||||
- id: ruff
|
||||
name: ruff
|
||||
entry: ruff . --preview
|
||||
entry: ruff check . --preview
|
||||
language: system
|
||||
always_run: true
|
||||
pass_filenames: false
|
||||
|
||||
@@ -78,7 +78,6 @@ class LazyOp:
|
||||
@dataclass
|
||||
class FlopCounter:
|
||||
shape: Tuple[int, ...]
|
||||
dtype: DType
|
||||
flops: sint
|
||||
mem: Dict[int, int]
|
||||
@property
|
||||
@@ -88,14 +87,14 @@ class FlopCounter:
|
||||
return ret
|
||||
|
||||
InterpretedFlopCounter: Dict[Op, Callable] = {
|
||||
BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {arg.idx: arg.dtype.itemsize*arg.st.real_size()}),
|
||||
BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {}),
|
||||
BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, arg.dtype, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize*arg.st.real_size()}), # noqa: E501
|
||||
UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem), # cast uses no flops
|
||||
**{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op is not UnaryOps.CAST}, # noqa: E501
|
||||
**{op:lambda self,y,op=op: FlopCounter(self.shape, dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else self.dtype, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
|
||||
**{op:lambda self,axis: FlopCounter(tuple(1 if i in axis else s for i,s in enumerate(self.shape)), self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, # noqa: E501
|
||||
TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, y.dtype, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501
|
||||
BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, 0, {arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
|
||||
BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, 0, {}),
|
||||
BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
|
||||
UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # cast uses no flops
|
||||
**{op:lambda self: FlopCounter(self.shape, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op is not UnaryOps.CAST},
|
||||
**{op:lambda self,y: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
|
||||
**{op:lambda self,axis: FlopCounter(tuple(1 if i in axis else s for i,s in enumerate(self.shape)), self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, # noqa: E501
|
||||
TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_lazyop_info(ast:LazyOp) -> FlopCounter:
|
||||
|
||||
Reference in New Issue
Block a user