mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
ruff lint tinykitten (#13762)
deleted used import and double spaces. a few ignore to not change the real code
This commit is contained in:
1
.github/workflows/test.yml
vendored
1
.github/workflows/test.yml
vendored
@@ -236,6 +236,7 @@ jobs:
|
||||
pip3 install --upgrade --force-reinstall ruff==0.11.0
|
||||
python3 -m ruff check .
|
||||
python3 -m ruff check examples/mlperf/ --ignore E501
|
||||
python3 -m ruff check extra/thunder/tiny/ --ignore E501 --ignore F841 --ignore E722
|
||||
- name: Run mypy
|
||||
run: |
|
||||
python -m mypy --strict-equality --lineprecision-report .
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import math
|
||||
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.uop.ops import UOp
|
||||
|
||||
from extra.thunder.tiny.tk import WARP_THREADS
|
||||
from extra.thunder.tiny.tk.kernel import Kernel
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
import math, functools
|
||||
import math
|
||||
from typing import cast, Callable
|
||||
from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes
|
||||
from tinygrad.uop.ops import AxisType, UOp, KernelInfo, Ops
|
||||
from tinygrad.engine.realize import ExecItem, get_runner
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.uop.ops import AxisType, UOp, Ops
|
||||
from tinygrad.dtype import AddrSpace, PtrDType
|
||||
from tinygrad.helpers import getenv, prod
|
||||
from tinygrad.helpers import prod
|
||||
|
||||
from extra.thunder.tiny.tk import WARP_THREADS
|
||||
from extra.thunder.tiny.tk.tiles import ALL_TILES, GL, RT_16X16, RT_16X32, ST, RT, RV, TileLayout, VecLayout
|
||||
from extra.thunder.tiny.tk.tiles import ALL_TILES, ST, RT, RV, TileLayout, VecLayout
|
||||
|
||||
class Group:
|
||||
def __init__(self, warps:int, ker):
|
||||
@@ -83,7 +82,7 @@ class Group:
|
||||
wmma_arg = ('WMMA_16_16_16___bf16_float', (16, 16, 16), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2)), ((4, 2), (3, 2)), ((4, 2), (3, 2))), ()) # type: ignore
|
||||
elif a_base_shape.cols == 32:
|
||||
wmma_arg = ('WMMA_16_16_32___bf16_float', (16, 16, 32), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2))), ()) # type: ignore
|
||||
else: raise NotImplementedError(f"mma_AB not implemented for {a_base_shape.cols=}")
|
||||
else: raise NotImplementedError(f"mma_AB not implemented for {a_base_shape.cols=}")
|
||||
|
||||
for height in self.ker.range(c.shape[-3], track=False):
|
||||
for width in self.ker.range(c.shape[-2], track=False):
|
||||
@@ -113,7 +112,7 @@ class Group:
|
||||
wmma_arg = ('WMMA_16_16_16___bf16_float', (16, 16, 16), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2)), ((4, 2), (3, 2)), ((4, 2), (3, 2))), ()) # type: ignore
|
||||
elif a_base_shape.cols == 32:
|
||||
wmma_arg = ('WMMA_16_16_32___bf16_float', (16, 16, 32), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2))), ()) # type: ignore
|
||||
else: raise NotImplementedError(f"mma_ABt not implemented for {a_base_shape.cols=}")
|
||||
else: raise NotImplementedError(f"mma_ABt not implemented for {a_base_shape.cols=}")
|
||||
|
||||
for height in self.ker.range(c.shape[-3], track=False):
|
||||
for width in self.ker.range(c.shape[-2], track=False):
|
||||
@@ -143,7 +142,7 @@ class Group:
|
||||
wmma_arg = ('WMMA_16_16_16___bf16_float', (16, 16, 16), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2)), ((4, 2), (3, 2)), ((4, 2), (3, 2))), ()) # type: ignore
|
||||
elif a_base_shape.cols == 32:
|
||||
wmma_arg = ('WMMA_16_16_32___bf16_float', (16, 16, 32), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2))), ()) # type: ignore
|
||||
else: raise NotImplementedError(f"mma_AtB not implemented for {a_base_shape.cols=}")
|
||||
else: raise NotImplementedError(f"mma_AtB not implemented for {a_base_shape.cols=}")
|
||||
|
||||
for height in self.ker.range(c.shape[-3], track=False):
|
||||
for width in self.ker.range(c.shape[-2], track=False):
|
||||
@@ -173,7 +172,7 @@ class Group:
|
||||
wmma_arg = ('WMMA_16_16_16___bf16_float', (16, 16, 16), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2)), ((4, 2), (3, 2)), ((4, 2), (3, 2))), ()) # type: ignore
|
||||
elif a_base_shape.cols == 32:
|
||||
wmma_arg = ('WMMA_16_16_32___bf16_float', (16, 16, 32), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2))), ()) # type: ignore
|
||||
else: raise NotImplementedError(f"mma_AtBt not implemented for {a_base_shape.cols=}")
|
||||
else: raise NotImplementedError(f"mma_AtBt not implemented for {a_base_shape.cols=}")
|
||||
|
||||
for height in self.ker.range(c.shape[-3], track=False):
|
||||
for width in self.ker.range(c.shape[-2], track=False):
|
||||
|
||||
@@ -2,7 +2,7 @@ from contextlib import AbstractContextManager
|
||||
from tinygrad.uop.ops import UOp, KernelInfo, AxisType, AddrSpace
|
||||
from extra.thunder.tiny.tk import WARP_THREADS
|
||||
from extra.thunder.tiny.tk.group import Group
|
||||
from extra.thunder.tiny.tk.tiles import GL, ST_16X16, ST_16X16_SWIZZLED, ST, RT_16X16, RT, RV, TileLayout, VecLayout
|
||||
from extra.thunder.tiny.tk.tiles import GL, ST_16X16, ST, RT_16X16, RT, RV, TileLayout, VecLayout
|
||||
|
||||
class _tk_range:
|
||||
def __init__(self, start:int, end:int, step:int, axis_type:AxisType, rid:int):
|
||||
|
||||
Reference in New Issue
Block a user