Support packed types in smem in webgpu

This commit is contained in:
Ahmed Harmouche
2024-12-02 09:59:05 +01:00
parent 61b2cac507
commit 1ea0925744
2 changed files with 14 additions and 4 deletions

View File

@@ -315,6 +315,16 @@ class TestLocalAccess(unittest.TestCase):
sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr))
self.assertEqual(_test_uops_result(dtypes.float32, uops, sres), 42)
# NOTE: webgpu specific, since only webgpu performs bitpacking for uchar
@unittest.skipUnless(Device.DEFAULT == "WEBGPU", "Test local access with packed data type")
def test_local_packed(self):
uops = []
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.uint8.ptr(local=True), (), ('smem', 16))
st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.uint8, (), 42)))
barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
sres = uop(uops, Ops.LOAD, dtypes.uint8, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr))
self.assertEqual(_test_uops_result(dtypes.uint8, uops, sres), 42)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
def test_local_indirect(self):
uops = []