mirror of https://github.com/tinygrad/tinygrad.git
add E275 missing-whitespace-after-keyword linting rule (#6149)
requires a space after keywords like `assert`, `not`, `return`, and `else`
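For context, this is the sort of code E275 flags; a minimal illustrative snippet (not taken from this diff):

    x = 0
    assert(x == 0)          # E275: no space after the `assert` keyword
    assert (x == 0)         # clean
    if not(x == 1): pass    # E275: no space after `not`
    if not (x == 1): pass   # clean

In all four lines the parentheses just group an expression, so the fix is pure whitespace; the rule restores the visual cue that a keyword, not a function, is involved.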
@@ -12,6 +12,7 @@ lint.select = [
   # "E124",
   "E203", # whitespace-before-punctuation
   "E272", # multiple-spaces-before-keyword
+  "E275", # missing-whitespace-after-keyword
   "E303", # too-many-blank-lines
   "E304", # blank-line-after-decorator
   "E501", # line-too-long
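Assuming a recent ruff, the new rule can be tried locally with `ruff check --select E275 .`; pre-existing violations like the ones fixed below are reported under this code.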
@@ -13,25 +13,25 @@ class TestGC(unittest.TestCase):
     a = Tensor.rand(4, 4, requires_grad=True)
     b = Tensor.zeros(4, 4, requires_grad=True)
     (a*b).mean().backward()
-    assert(tensors_allocated() > 0)
+    assert (tensors_allocated() > 0)
     del a,b
-    assert(tensors_allocated() == 1) # one for Tensor._rng_counter
+    assert (tensors_allocated() == 1) # one for Tensor._rng_counter

   def test_gc_complex(self):
     a = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
     b = Tensor.rand(4, 4, requires_grad=True)
-    assert(tensors_allocated() == 3)
+    assert (tensors_allocated() == 3)
     (a*b).mean().backward()
-    assert(tensors_allocated() == 5)
+    assert (tensors_allocated() == 5)
     del b
-    assert(tensors_allocated() == 3)
+    assert (tensors_allocated() == 3)
     b = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
     print(tensors_allocated())
     (a*b).mean().backward()
     print(tensors_allocated())
-    assert(tensors_allocated() == 5)
+    assert (tensors_allocated() == 5)
     del b
-    assert(tensors_allocated() == 3)
+    assert (tensors_allocated() == 3)

 if __name__ == '__main__':
   unittest.main()
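Linting `assert` this way also surfaces a classic Python pitfall: written call-style, adding a failure message turns the statement into an assert on a non-empty tuple, which always passes. A hypothetical snippet, not from this diff:

    assert(tensors_allocated() == 3, "tensor leak")   # always true: asserts the 2-tuple itself; CPython emits a SyntaxWarning
    assert tensors_allocated() == 3, "tensor leak"    # checks the condition and attaches the message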
@@ -148,7 +148,7 @@ class TestFetch(unittest.TestCase):
     self.assertRaises(Exception, fetch, 'http://www.google.com/404', allow_caching=False)

   def test_fetch_small(self):
-    assert(len(fetch('https://google.com', allow_caching=False).read_bytes())>0)
+    assert (len(fetch('https://google.com', allow_caching=False).read_bytes())>0)

   def test_fetch_img(self):
     img = fetch("https://avatars.githubusercontent.com/u/132956020", allow_caching=False)
@@ -367,38 +367,38 @@ class TestSimplifyingShapeTracker(unittest.TestCase):
     self.st = self.st.expand((10, 10))
     self.st = self.st.reshape((100,))
     print(self.st.views)
-    assert(len(self.st.views) == 2)
+    assert (len(self.st.views) == 2)
     self.st = self.st.reshape((10, 10))
     print(self.st.views)

     self.st = self.st.simplify()
     print(self.st.views)
-    assert(len(self.st.views) == 1)
+    assert (len(self.st.views) == 1)

   # multiview simplify
   def test_expand_contract_different_shape(self):
     self.st.expand((10, 10))
     self.st.reshape((100,))
     print(self.st.views)
-    assert(len(self.st.views) == 2)
+    assert (len(self.st.views) == 2)
     self.st.reshape((2, 5, 2, 5))
     print(self.st.views)

     self.st = self.st.simplify()
     print(self.st.views)
-    assert(len(self.st.views) == 1)
+    assert (len(self.st.views) == 1)

   # multiview simplify
   def test_expand_contract_still_complex(self):
     self.st.expand((10, 10))
     self.st.reshape((100,))
     print(self.st.views)
-    assert(len(self.st.views) == 2)
+    assert (len(self.st.views) == 2)
     self.st.reshape((5, 20))

     self.st = self.st.simplify()
     print(self.st.views)
-    assert(len(self.st.views) == 2)
+    assert (len(self.st.views) == 2)

   # Tensor.zeros(2, 4).permute(1,0).reshape(2, 4)
   # (d1*4 + d0%4), d1=x//4, d0=x%4 = ((x//4)*4) + (x%4)%4
@@ -278,7 +278,7 @@ class Kernel:

   def _create_tc_opts(self, reduceop:UOp, tc:TensorCore, axis:int, opt_level:int) -> Optional[TensorCoreOptions]:
     has_cast = tc.dtype_in != tc.dtype_out
-    if has_cast and not(reduceop.src[0].op is UOps.CAST and reduceop.src[0].dtype == tc.dtype_out): return None
+    if has_cast and not (reduceop.src[0].op is UOps.CAST and reduceop.src[0].dtype == tc.dtype_out): return None

     mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
     if mul_op.arg is not BinaryOps.MUL: return None
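Note that the `not(...)` rewrites in these Kernel hunks are behavior-preserving: `not` is an operator, so `not(expr)` already parsed as `not` applied to a parenthesized expression; only readability changes.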
@@ -295,10 +295,10 @@ class Kernel:
     buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
     axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
     axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
-    if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): return None
+    if not (axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): return None

     axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
-    if not(axis < len(axis_choices)): return None
+    if not (axis < len(axis_choices)): return None

     s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
     axis_pads = tuple((x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if self.full_shape[x]%tc.dims[i] != 0)
@@ -437,7 +437,7 @@ class Kernel:
       self.upcast()
     elif opt.op is OptOps.UPCAST: # yellow
       check(axis < self.first_reduce, "upcast is for non-reduce")
-      check(not(self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.threads)), "can't upcast TC locals")
+      check(not (self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.threads)), "can't upcast TC locals")
       check(amt <= 16, "don't upcast more than 16")
       self.shift_to(axis, amt, insert_before=None)
       self.upcast()
@@ -347,5 +347,5 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str:
       if cache[s][1] == 1: dfs(s, cache)
   if cache is None: dfs(x, cache:={})
   if (cx:=cache.setdefault(x, [0,0,False]))[2]: return f"{' '*d} x{cx[0]}"
-  cx[2], srcs = True, ('None' if srcfn(x) is None else''.join(f'\n{pretty_print(s, rep, srcfn, cache, d+2)},' for s in srcfn(x)))
+  cx[2], srcs = True, ('None' if srcfn(x) is None else ''.join(f'\n{pretty_print(s, rep, srcfn, cache, d+2)},' for s in srcfn(x)))
   return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
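The `pretty_print` change is the subtlest of the batch: `else''.join(...)` is legal Python, since a string literal can directly follow the `else` keyword of a conditional expression, but it reads as a single token. The added space changes nothing at runtime.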
@@ -131,7 +131,7 @@ class PTXRenderer(Renderer):

   def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]:
     if bitcast: return [f"mov.b{self.types[dtype][1:]} {d}, {a};"]
-    if atype == dtypes.bool: return[f"selp.b{self.types[dtype][1:]} {d}, {render_val(1, dtype)}, {render_val(0, dtype)}, {a};"]
+    if atype == dtypes.bool: return [f"selp.b{self.types[dtype][1:]} {d}, {render_val(1, dtype)}, {render_val(0, dtype)}, {a};"]
     if dtype == dtypes.bool: return [f"setp.ne.b{self.types[atype][1:]} {d}, {a}, {self.render_const(0, atype)};"]
     rnd = ('.rzi' if dtypes.is_int(dtype) and dtypes.is_float(atype) else
            '.rn' if dtypes.is_float(dtype) and (dtype.itemsize < atype.itemsize or dtypes.is_int(atype) or atype == dtypes.bool) else '')