mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
add E275 missing-whitespace-after-keyword linting rule (#6149)
requires space after keywords like `assert`, `not`, `return`, `else`
This commit is contained in:
@@ -13,25 +13,25 @@ class TestGC(unittest.TestCase):
|
||||
a = Tensor.rand(4, 4, requires_grad=True)
|
||||
b = Tensor.zeros(4, 4, requires_grad=True)
|
||||
(a*b).mean().backward()
|
||||
assert(tensors_allocated() > 0)
|
||||
assert (tensors_allocated() > 0)
|
||||
del a,b
|
||||
assert(tensors_allocated() == 1) # one for Tensor._rng_counter
|
||||
assert (tensors_allocated() == 1) # one for Tensor._rng_counter
|
||||
|
||||
def test_gc_complex(self):
|
||||
a = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
|
||||
b = Tensor.rand(4, 4, requires_grad=True)
|
||||
assert(tensors_allocated() == 3)
|
||||
assert (tensors_allocated() == 3)
|
||||
(a*b).mean().backward()
|
||||
assert(tensors_allocated() == 5)
|
||||
assert (tensors_allocated() == 5)
|
||||
del b
|
||||
assert(tensors_allocated() == 3)
|
||||
assert (tensors_allocated() == 3)
|
||||
b = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
|
||||
print(tensors_allocated())
|
||||
(a*b).mean().backward()
|
||||
print(tensors_allocated())
|
||||
assert(tensors_allocated() == 5)
|
||||
assert (tensors_allocated() == 5)
|
||||
del b
|
||||
assert(tensors_allocated() == 3)
|
||||
assert (tensors_allocated() == 3)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -148,7 +148,7 @@ class TestFetch(unittest.TestCase):
|
||||
self.assertRaises(Exception, fetch, 'http://www.google.com/404', allow_caching=False)
|
||||
|
||||
def test_fetch_small(self):
|
||||
assert(len(fetch('https://google.com', allow_caching=False).read_bytes())>0)
|
||||
assert (len(fetch('https://google.com', allow_caching=False).read_bytes())>0)
|
||||
|
||||
def test_fetch_img(self):
|
||||
img = fetch("https://avatars.githubusercontent.com/u/132956020", allow_caching=False)
|
||||
|
||||
@@ -367,38 +367,38 @@ class TestSimplifyingShapeTracker(unittest.TestCase):
|
||||
self.st = self.st.expand((10, 10))
|
||||
self.st = self.st.reshape((100,))
|
||||
print(self.st.views)
|
||||
assert(len(self.st.views) == 2)
|
||||
assert (len(self.st.views) == 2)
|
||||
self.st = self.st.reshape((10, 10))
|
||||
print(self.st.views)
|
||||
|
||||
self.st = self.st.simplify()
|
||||
print(self.st.views)
|
||||
assert(len(self.st.views) == 1)
|
||||
assert (len(self.st.views) == 1)
|
||||
|
||||
# multiview simplify
|
||||
def test_expand_contract_different_shape(self):
|
||||
self.st.expand((10, 10))
|
||||
self.st.reshape((100,))
|
||||
print(self.st.views)
|
||||
assert(len(self.st.views) == 2)
|
||||
assert (len(self.st.views) == 2)
|
||||
self.st.reshape((2, 5, 2, 5))
|
||||
print(self.st.views)
|
||||
|
||||
self.st = self.st.simplify()
|
||||
print(self.st.views)
|
||||
assert(len(self.st.views) == 1)
|
||||
assert (len(self.st.views) == 1)
|
||||
|
||||
# multiview simplify
|
||||
def test_expand_contract_still_complex(self):
|
||||
self.st.expand((10, 10))
|
||||
self.st.reshape((100,))
|
||||
print(self.st.views)
|
||||
assert(len(self.st.views) == 2)
|
||||
assert (len(self.st.views) == 2)
|
||||
self.st.reshape((5, 20))
|
||||
|
||||
self.st = self.st.simplify()
|
||||
print(self.st.views)
|
||||
assert(len(self.st.views) == 2)
|
||||
assert (len(self.st.views) == 2)
|
||||
|
||||
# Tensor.zeros(2, 4).permute(1,0).reshape(2, 4)
|
||||
# (d1*4 + d0%4), d1=x//4, d0=x%4 = ((x//4)*4) + (x%4)%4
|
||||
|
||||
Reference in New Issue
Block a user