mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 06:18:01 -05:00
small changes from lowerer (#5266)
@@ -728,8 +728,9 @@ class TestLinearizer(unittest.TestCase):
     helper(Tensor.arange(5.5, (3.5*300), 3.5), max_ops=2)
     helper(Tensor.arange(-1, -100, -5), max_ops=2)
-    helper(Tensor.arange(-3.2, 6.7, 0.64), max_ops=2)
-    helper(Tensor.arange(256), max_ops=2)
+    # NOTE: both of these split the reduce (this just wasn't tracked before)
+    #helper(Tensor.arange(-3.2, 6.7, 0.64), max_ops=2)
+    #helper(Tensor.arange(256), max_ops=2)
     helper(Tensor.arange(255), max_ops=2)
 
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
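For context on the arange cases above: `helper` asserts both the values and that the lowered kernel stays under `max_ops`, and the two now-commented cases split the reduce into more ops than that budget. A value-only check of the same ranges, assuming a local tinygrad install (the op counting is not reproduced here):

```python
import numpy as np
from tinygrad import Tensor

# the same ranges the test exercises, checked for values only
for args in [(5.5, 3.5*300, 3.5), (-1, -100, -5), (-3.2, 6.7, 0.64), (256,), (255,)]:
  np.testing.assert_allclose(Tensor.arange(*args).numpy(), np.arange(*args), rtol=1e-5, atol=1e-4)
```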
@@ -775,7 +776,7 @@ class TestLinearizer(unittest.TestCase):
     barrier = [u for u in k.uops if u.op is UOps.BARRIER][0]
     # check that the float4 cast collapses for all stores
     for store in local_stores+global_stores:
-      assert store.src[-1].dtype == dtypes.float.vec(2) and store.src[-1].op is not UOps.CAST
+      assert store.src[2].dtype == dtypes.float.vec(2) and store.src[2].op is not UOps.CAST
     # check the children's vins
     assert barrier.src == tuple(local_stores)
     assert len([u for u in k.uops if u.op is UOps.IF and u.src[-1] == barrier]) == 1
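The assertions in this hunk walk the linearized uop list by hand. A sketch of the same structural checks factored into helpers; `uops`, `local_stores`, `UOps`, and `dtypes` are assumed to be what the test already has in scope, and index 2 as the store's value operand follows the change above:

```python
# a sketch, not tinygrad API: the same checks the test performs inline
def assert_stores_vectorized(stores, UOps, dtypes, width=2):
  # every store should write a float.vec(width) value directly, with no CAST
  for store in stores:
    val = store.src[2]  # value operand of a STORE, per the change above
    assert val.dtype == dtypes.float.vec(width) and val.op is not UOps.CAST

def assert_single_gated_barrier(uops, local_stores, UOps):
  # the barrier must depend on exactly the local stores and be consumed by one IF
  barrier = [u for u in uops if u.op is UOps.BARRIER][0]
  assert barrier.src == tuple(local_stores)
  assert len([u for u in uops if u.op is UOps.IF and u.src[-1] == barrier]) == 1
```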
@@ -765,6 +765,12 @@ class TestOps(unittest.TestCase):
     helper_test_op([(4,3), (1,3,3,5)], lambda x,y: x.matmul(y), Tensor.dot)
   def test_small_gemm(self):
     helper_test_op([(8,8), (8,8)], lambda x,y: x.matmul(y), lambda x,y: x@y)
+  def test_9_gemm(self):
+    helper_test_op([(9,9), (9,9)], lambda x,y: x.matmul(y), lambda x,y: x@y)
+  def test_small_gemm_padded(self):
+    helper_test_op([(9,9), (9,9)],
+      lambda x,y: torch.nn.functional.pad(x, (0,7,0,7)).matmul(torch.nn.functional.pad(y, (0,7,0,7))),
+      lambda x,y: x.pad(((0,7),(0,7)))@y.pad(((0,7),(0,7))))
   def test_small_gemm_range(self):
     helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, vals=[np.arange(0,64,dtype=np.float32).reshape(8,8),
                                                                          np.arange(64,128,dtype=np.float32).reshape(8,8)])
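test_small_gemm_padded leans on the identity that zero-padding both operands leaves the original product unchanged in the top-left block. A quick numpy check of that identity (values are illustrative, not the ones helper_test_op generates):

```python
import numpy as np

x = np.arange(81, dtype=np.float32).reshape(9, 9)
y = np.arange(81, 162, dtype=np.float32).reshape(9, 9)
# pad 9x9 -> 16x16 with zeros on the bottom/right, as the test does
xp, yp = np.pad(x, ((0, 7), (0, 7))), np.pad(y, ((0, 7), (0, 7)))
# the padded product's top-left 9x9 block equals the unpadded product
assert np.allclose((xp @ yp)[:9, :9], x @ y)
```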
@@ -296,7 +296,7 @@ class Kernel:
       bst *= shp[j]
 
     self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),)))
-    self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size))
+    self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].real_size()))  # real_size ignores the 0's
     if DEBUG >= 4: print("aliasing buffer", self.sts[i])
     self.local_alias[op][i] = cast(LocalBuffer, self.bufs[-1])
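A hedged sketch of what the "real_size ignores the 0's" comment is getting at: a stride-0 axis is a broadcast that re-reads the same memory, so the local buffer only needs the highest addressed element plus one rather than prod(shape). A toy single-view version (assumes non-negative strides and no mask); it is an illustration, not tinygrad's ShapeTracker.real_size:

```python
# illustration only: memory footprint of a single (shape, strides) view,
# assuming non-negative strides and no mask
def real_size(shape: tuple[int, ...], strides: tuple[int, ...]) -> int:
  if 0 in shape: return 0                      # empty view addresses nothing
  # highest flat index touched, plus one; stride-0 axes contribute nothing
  return 1 + sum((s - 1) * st for s, st in zip(shape, strides))

assert real_size((8, 8), (8, 1)) == 64  # dense view: same as prod(shape)
assert real_size((8, 8), (1, 0)) == 8   # second axis broadcast: 8x smaller
```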