small changes from lowerer (#5266)

This commit is contained in:
George Hotz
2024-07-02 15:03:54 -07:00
committed by GitHub
parent 7be776f9af
commit e53b164e1a
3 changed files with 11 additions and 4 deletions

View File

@@ -728,8 +728,9 @@ class TestLinearizer(unittest.TestCase):
helper(Tensor.arange(5.5, (3.5*300), 3.5), max_ops=2)
helper(Tensor.arange(-1, -100, -5), max_ops=2)
helper(Tensor.arange(-3.2, 6.7, 0.64), max_ops=2)
helper(Tensor.arange(256), max_ops=2)
# NOTE: both of these split the reduce (this just wasn't tracked before)
#helper(Tensor.arange(-3.2, 6.7, 0.64), max_ops=2)
#helper(Tensor.arange(256), max_ops=2)
helper(Tensor.arange(255), max_ops=2)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
@@ -775,7 +776,7 @@ class TestLinearizer(unittest.TestCase):
barrier = [u for u in k.uops if u.op is UOps.BARRIER][0]
# check that the float4 cast collapses for all stores
for store in local_stores+global_stores:
assert store.src[-1].dtype == dtypes.float.vec(2) and store.src[-1].op is not UOps.CAST
assert store.src[2].dtype == dtypes.float.vec(2) and store.src[2].op is not UOps.CAST
# check the children's vins
assert barrier.src == tuple(local_stores)
assert len([u for u in k.uops if u.op is UOps.IF and u.src[-1] == barrier]) == 1

View File

@@ -765,6 +765,12 @@ class TestOps(unittest.TestCase):
helper_test_op([(4,3), (1,3,3,5)], lambda x,y: x.matmul(y), Tensor.dot)
def test_small_gemm(self):
helper_test_op([(8,8), (8,8)], lambda x,y: x.matmul(y), lambda x,y: x@y)
def test_9_gemm(self):
helper_test_op([(9,9), (9,9)], lambda x,y: x.matmul(y), lambda x,y: x@y)
def test_small_gemm_padded(self):
helper_test_op([(9,9), (9,9)],
lambda x,y: torch.nn.functional.pad(x, (0,7,0,7)).matmul(torch.nn.functional.pad(y, (0,7,0,7))),
lambda x,y: x.pad(((0,7),(0,7)))@y.pad(((0,7),(0,7))))
def test_small_gemm_range(self):
helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, vals=[np.arange(0,64,dtype=np.float32).reshape(8,8),
np.arange(64,128,dtype=np.float32).reshape(8,8)])

View File

@@ -296,7 +296,7 @@ class Kernel:
bst *= shp[j]
self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),)))
self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size))
self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].real_size())) # real_size ignores the 0's
if DEBUG >= 4: print("aliasing buffer", self.sts[i])
self.local_alias[op][i] = cast(LocalBuffer, self.bufs[-1])