mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
index in cstyle (#7328)
* index only in cstyle
* fix prefix dtypes
* fix tests
* global indexing
* Revert "global indexing"
This reverts commit 4d507e8abb.
* fix image
* fix image
* ptx tests
* fix CUDA dtype rendering
This commit is contained in:
@@ -942,10 +942,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
sink = UOp(UOps.SINK, src=(store,))
|
||||
lin = Kernel(sink)
|
||||
lin.linearize()
|
||||
|
||||
assert len(lin.uops) <= 7, "too many uops"
|
||||
a_bufs = [u.op for u in lin.uops[-1].src[2].src]
|
||||
assert a_bufs == [UOps.LOAD, UOps.CONST]
|
||||
assert len(lin.uops) <= 9, "too many uops"
|
||||
|
||||
def test_upcast_cse(self):
|
||||
# when upcasting, within a subtree, there may be common expressions.
|
||||
@@ -989,10 +986,10 @@ class TestLinearizer(unittest.TestCase):
|
||||
|
||||
# the first store is to lds and can be upcasted
|
||||
assert stores[0].src[-1].dtype == dtypes.float.vec(4)
|
||||
assert stores[0].src[0].op is UOps.DEFINE_LOCAL
|
||||
assert any(x.op is UOps.DEFINE_LOCAL for x in stores[0].sparents)
|
||||
# the second store is to gds with no upcasts
|
||||
assert stores[1].src[2].dtype == dtypes.float
|
||||
assert stores[1].src[0].op is UOps.DEFINE_GLOBAL
|
||||
assert stores[1].src[-1].dtype == dtypes.float
|
||||
assert any(x.op is UOps.DEFINE_GLOBAL for x in stores[1].sparents)
|
||||
|
||||
def test_zero_fold(self):
|
||||
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
|
||||
@@ -1340,7 +1337,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
barrier = [u for u in k.uops if u.op is UOps.BARRIER][0]
|
||||
# check that the float4 cast collapses for all stores
|
||||
for store in local_stores+global_stores:
|
||||
assert store.src[2].dtype.count > 1 # and store.src[2].op is not UOps.VECTORIZE
|
||||
assert store.src[-1].dtype.count > 1 # and store.src[2].op is not UOps.VECTORIZE
|
||||
# # check the children's vins
|
||||
# TODO: the src ALUs are not the same; should they be?
|
||||
# assert barrier.src == tuple(local_stores)
|
||||
@@ -1360,7 +1357,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
#assert stores[0].src[-1].op is not UOps.VECTORIZE
|
||||
|
||||
# the global store doesn't change
|
||||
assert stores[1].src[2].dtype == dtypes.float
|
||||
assert stores[1].src[-1].dtype == dtypes.float
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
|
||||
@@ -1404,11 +1401,11 @@ class TestFloat4(unittest.TestCase):
|
||||
@staticmethod
|
||||
def count_float4(k, n=4):
|
||||
return (len([uop for uop in k.uops if uop.op is UOps.LOAD and uop.dtype == dtypes.float.vec(n)]),
|
||||
len([uop for uop in k.uops if uop.op is UOps.STORE and len(uop.src) == 3 and uop.src[2].dtype == dtypes.float.vec(n)]))
|
||||
len([uop for uop in k.uops if uop.op is UOps.STORE and uop.src[-1].dtype == dtypes.float.vec(n)]))
|
||||
@staticmethod
|
||||
def count_half4(k):
|
||||
return (len([uop for uop in k.uops if uop.op is UOps.LOAD and uop.dtype == dtypes.half.vec(4)]),
|
||||
len([uop for uop in k.uops if uop.op is UOps.STORE and len(uop.src) == 3 and uop.src[2].dtype == dtypes.half.vec(4)]))
|
||||
len([uop for uop in k.uops if uop.op is UOps.STORE and uop.src[-1].dtype == dtypes.half.vec(4)]))
|
||||
|
||||
# TODO: express opts below as auto opts
|
||||
|
||||
|
||||
Reference in New Issue
Block a user