index in cstyle (#7328)

* index only in cstyle

* fix prefix dtypes

* fix tests

* global indexing

* Revert "global indexing"

This reverts commit 4d507e8abb.

* fix image

* fix image

* ptx tests

* fix CUDA dtype rendering
This commit is contained in:
George Hotz
2024-10-29 12:06:26 +07:00
committed by GitHub
parent f55c3dcff8
commit 4cb236a495
6 changed files with 74 additions and 53 deletions

View File

@@ -942,10 +942,7 @@ class TestLinearizer(unittest.TestCase):
sink = UOp(UOps.SINK, src=(store,))
lin = Kernel(sink)
lin.linearize()
assert len(lin.uops) <= 7, "too many uops"
a_bufs = [u.op for u in lin.uops[-1].src[2].src]
assert a_bufs == [UOps.LOAD, UOps.CONST]
assert len(lin.uops) <= 9, "too many uops"
def test_upcast_cse(self):
# when upcasting, within a subtree, there may be common expressions.
@@ -989,10 +986,10 @@ class TestLinearizer(unittest.TestCase):
# the first store is to lds and can be upcasted
assert stores[0].src[-1].dtype == dtypes.float.vec(4)
assert stores[0].src[0].op is UOps.DEFINE_LOCAL
assert any(x.op is UOps.DEFINE_LOCAL for x in stores[0].sparents)
# the second store is to gds with no upcasts
assert stores[1].src[2].dtype == dtypes.float
assert stores[1].src[0].op is UOps.DEFINE_GLOBAL
assert stores[1].src[-1].dtype == dtypes.float
assert any(x.op is UOps.DEFINE_GLOBAL for x in stores[1].sparents)
def test_zero_fold(self):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
@@ -1340,7 +1337,7 @@ class TestLinearizer(unittest.TestCase):
barrier = [u for u in k.uops if u.op is UOps.BARRIER][0]
# check that the float4 cast collapses for all stores
for store in local_stores+global_stores:
assert store.src[2].dtype.count > 1 # and store.src[2].op is not UOps.VECTORIZE
assert store.src[-1].dtype.count > 1 # and store.src[2].op is not UOps.VECTORIZE
# # check the children's vins
# TODO: src ALU are not the same, should it?
# assert barrier.src == tuple(local_stores)
@@ -1360,7 +1357,7 @@ class TestLinearizer(unittest.TestCase):
#assert stores[0].src[-1].op is not UOps.VECTORIZE
# the global store doesn't change
assert stores[1].src[2].dtype == dtypes.float
assert stores[1].src[-1].dtype == dtypes.float
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
@@ -1404,11 +1401,11 @@ class TestFloat4(unittest.TestCase):
@staticmethod
def count_float4(k, n=4):
return (len([uop for uop in k.uops if uop.op is UOps.LOAD and uop.dtype == dtypes.float.vec(n)]),
len([uop for uop in k.uops if uop.op is UOps.STORE and len(uop.src) == 3 and uop.src[2].dtype == dtypes.float.vec(n)]))
len([uop for uop in k.uops if uop.op is UOps.STORE and uop.src[-1].dtype == dtypes.float.vec(n)]))
@staticmethod
def count_half4(k):
return (len([uop for uop in k.uops if uop.op is UOps.LOAD and uop.dtype == dtypes.half.vec(4)]),
len([uop for uop in k.uops if uop.op is UOps.STORE and len(uop.src) == 3 and uop.src[2].dtype == dtypes.half.vec(4)]))
len([uop for uop in k.uops if uop.op is UOps.STORE and uop.src[-1].dtype == dtypes.half.vec(4)]))
# TODO: express opts below as auto opts