diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000..c9689b2532 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,15 @@ +repos: + - repo: local + hooks: + - id: pylint + name: pylint + entry: pylint tinygrad/ + language: system + always_run: true + pass_filenames: false + - id: mypy + name: mypy + entry: mypy tinygrad/ --ignore-missing-imports + language: system + always_run: true + pass_filenames: false diff --git a/setup.py b/setup.py index cd4897230f..83b046172d 100644 --- a/setup.py +++ b/setup.py @@ -32,6 +32,8 @@ setup(name='tinygrad', "onnx", "onnx2torch", "mypy", + "pylint", + "pre-commit", ], }, include_package_data=True) diff --git a/tinygrad/ast.py b/tinygrad/ast.py index ac43cadd56..01351d2668 100644 --- a/tinygrad/ast.py +++ b/tinygrad/ast.py @@ -116,12 +116,12 @@ class ASTKernel: for i in range(1, len(shapes[0])): can_merge = [] for j in range(len(shapes)): - # TODO: added the always mergability of 1s, is this right? if so, add to shapetracker in the 1 case + # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case can_merge.append((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*strides[j][i]) or (strides[j][i] == 0 and rets[j][-1][1] == 0)) # more can merge than this - can_merge = all(can_merge) and i != self.first_reduce + mergeable = all(can_merge) and i != self.first_reduce for j in range(len(shapes)): - if can_merge: + if mergeable: rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i]) else: rets[j].append((shapes[j][i], strides[j][i])) diff --git a/tinygrad/shape/__init__.py b/tinygrad/shape/__init__.py index fc2d591879..1aa286a0db 100644 --- a/tinygrad/shape/__init__.py +++ b/tinygrad/shape/__init__.py @@ -2,7 +2,7 @@ from __future__ import annotations import os import functools -from typing import Tuple, Union, List, Optional +from typing import Tuple, Union, List, Optional, Any from tinygrad.helpers import prod from tinygrad.shape.symbolic import Variable @@ -21,7 +21,7 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tup return ret class View: - def __init__(self, shape:Tuple[int, ...], strides:Tuple[int, ...], offset:int=0): + def __init__(self, shape:Union[Tuple[int, ...],List[Any]], strides:Union[Tuple[int, ...],List[Any]], offset:int=0): self.shape, self.strides, self.offset = tuple(shape), tuple(strides), offset self.shape_strides = to_shape_strides(self.shape, self.strides)