mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
Support packed types in smem in webgpu
This commit is contained in:
@@ -315,6 +315,16 @@ class TestLocalAccess(unittest.TestCase):
|
||||
sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr))
|
||||
self.assertEqual(_test_uops_result(dtypes.float32, uops, sres), 42)
|
||||
|
||||
# NOTE: webgpu specific, since only webgpu performs bitpacking for uchar
|
||||
@unittest.skipUnless(Device.DEFAULT == "WEBGPU", "Test local access with packed data type")
|
||||
def test_local_packed(self):
|
||||
uops = []
|
||||
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.uint8.ptr(local=True), (), ('smem', 16))
|
||||
st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.uint8, (), 42)))
|
||||
barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
|
||||
sres = uop(uops, Ops.LOAD, dtypes.uint8, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr))
|
||||
self.assertEqual(_test_uops_result(dtypes.uint8, uops, sres), 42)
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
|
||||
def test_local_indirect(self):
|
||||
uops = []
|
||||
|
||||
@@ -68,7 +68,7 @@ class WGSLRenderer(CStyleLanguage):
|
||||
(UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"), lambda ctx,x: f"bitcast<u32>({x.arg})" \
|
||||
if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"),
|
||||
(UPat(Ops.CONST, dtype=dtypes.int32, name="x"), lambda ctx,x: f"bitcast<i32>({x.arg}u)" if x.arg >= 0x80000000 else f"{x.arg}"),
|
||||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{type_map[x.dtype.base]}, {x.arg[1]}>;"),
|
||||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.render_buf_dt(x.dtype.base, True)}, {x.arg[1]}>;"),
|
||||
(UPat(Ops.BITCAST, dtype=(dtypes.char, dtypes.uchar), name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]}&0xFF)"),
|
||||
(UPat(Ops.BITCAST, dtype=(dtypes.short, dtypes.ushort), name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]}&0xFFFF)"),
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]})"),
|
||||
@@ -78,7 +78,7 @@ class WGSLRenderer(CStyleLanguage):
|
||||
lambda ctx,buf,idx: f"{ctx[buf]}[{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]}]"),
|
||||
(UPat(Ops.STORE, src=(UPat.var('b'), UPat.var("v"))),lambda ctx,b,v:\
|
||||
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
|
||||
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\natomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if b.src[0].dtype.itemsize < 4 \
|
||||
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if b.src[0].dtype.itemsize < 4 \
|
||||
else f"{ctx[b]} = {ctx[v]};"),
|
||||
# fix nan check: 'a != a -> is_nan()'
|
||||
(UPat.var("a") != UPat.var("a"), lambda ctx,a: f"is_nan({ctx[a]})"),
|
||||
@@ -86,7 +86,7 @@ class WGSLRenderer(CStyleLanguage):
|
||||
|
||||
def render_cast(self, dt:DType, val: str) -> str: return f"{self.type_map[dt]}({val})"
|
||||
def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
|
||||
def render_buf(self, dt:DType, rw:bool) -> str: return f"{f'atomic<{buffer_map[dt]}>' if rw and (dt.itemsize < 4) else buffer_map[dt.base]}"
|
||||
def render_buf_dt(self, dt:DType, rw:bool) -> str: return f"{f'atomic<{buffer_map[dt]}>' if rw and (dt.itemsize < 4) else buffer_map[dt.base]}"
|
||||
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
|
||||
local_size = [num for _, num in sorted([u.arg for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == 'l'], key=lambda x: x[0])]
|
||||
if not local_size: local_size = [1]
|
||||
@@ -99,6 +99,6 @@ class WGSLRenderer(CStyleLanguage):
|
||||
prg += "@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
|
||||
prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" +
|
||||
f"{'var<storage,read_write>' if isinstance(dtype, PtrDType) else 'var<uniform>'}" +
|
||||
f"{name}:{f'array<{self.render_buf(dtype.base,rw)}>' if isinstance(dtype, PtrDType) else buffer_map[dtype]};" for name,(dtype,rw) in bufs])
|
||||
f"{name}:{f'array<{self.render_buf_dt(dtype.base,rw)}>' if isinstance(dtype, PtrDType) else buffer_map[dtype]};" for name,(dtype,rw) in bufs])
|
||||
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
|
||||
return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"
|
||||
|
||||
Reference in New Issue
Block a user