diff --git a/test/test_uops.py b/test/test_uops.py index d607f27fd5..effcc17c8e 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -315,6 +315,16 @@ class TestLocalAccess(unittest.TestCase): sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr)) self.assertEqual(_test_uops_result(dtypes.float32, uops, sres), 42) + # NOTE: webgpu specific, since only webgpu performs bitpacking for uchar + @unittest.skipUnless(Device.DEFAULT == "WEBGPU", "Test local access with packed data type") + def test_local_packed(self): + uops = [] + smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.uint8.ptr(local=True), (), ('smem', 16)) + st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.uint8, (), 42))) + barr = uop(uops, Ops.BARRIER, dtypes.void, (st,)) + sres = uop(uops, Ops.LOAD, dtypes.uint8, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr)) + self.assertEqual(_test_uops_result(dtypes.uint8, uops, sres), 42) + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory") def test_local_indirect(self): uops = [] diff --git a/tinygrad/renderer/wgsl.py b/tinygrad/renderer/wgsl.py index 9433849cf0..578a7422d2 100644 --- a/tinygrad/renderer/wgsl.py +++ b/tinygrad/renderer/wgsl.py @@ -68,7 +68,7 @@ class WGSLRenderer(CStyleLanguage): (UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"), lambda ctx,x: f"bitcast({x.arg})" \ if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"), (UPat(Ops.CONST, dtype=dtypes.int32, name="x"), lambda ctx,x: f"bitcast({x.arg}u)" if x.arg >= 0x80000000 else f"{x.arg}"), - (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{type_map[x.dtype.base]}, {x.arg[1]}>;"), + (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{ctx.render_buf_dt(x.dtype.base, True)}, {x.arg[1]}>;"), (UPat(Ops.BITCAST, dtype=(dtypes.char, dtypes.uchar), name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]}&0xFF)"), (UPat(Ops.BITCAST, dtype=(dtypes.short, dtypes.ushort), name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]}&0xFFFF)"), (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]})"), @@ -78,7 +78,7 @@ class WGSLRenderer(CStyleLanguage): lambda ctx,buf,idx: f"{ctx[buf]}[{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]}]"), (UPat(Ops.STORE, src=(UPat.var('b'), UPat.var("v"))),lambda ctx,b,v:\ # (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1] - f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\natomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if b.src[0].dtype.itemsize < 4 \ + f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if b.src[0].dtype.itemsize < 4 \ else f"{ctx[b]} = {ctx[v]};"), # fix nan check: 'a != a -> is_nan()' (UPat.var("a") != UPat.var("a"), lambda ctx,a: f"is_nan({ctx[a]})"), @@ -86,7 +86,7 @@ class WGSLRenderer(CStyleLanguage): def render_cast(self, dt:DType, val: str) -> str: return f"{self.type_map[dt]}({val})" def render_dtype(self, dt:DType, mutable=True) -> str: return "var" - def render_buf(self, dt:DType, rw:bool) -> str: return f"{f'atomic<{buffer_map[dt]}>' if rw and (dt.itemsize < 4) else buffer_map[dt.base]}" + def render_buf_dt(self, dt:DType, rw:bool) -> str: return f"{f'atomic<{buffer_map[dt]}>' if rw and (dt.itemsize < 4) else buffer_map[dt.base]}" def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str: local_size = [num for _, num in sorted([u.arg for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == 'l'], key=lambda x: x[0])] if not local_size: local_size = [1] @@ -99,6 +99,6 @@ class WGSLRenderer(CStyleLanguage): prg += "@group(0) @binding(0)\nvar INFINITY : f32;\n" prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" + f"{'var' if isinstance(dtype, PtrDType) else 'var'}" + - f"{name}:{f'array<{self.render_buf(dtype.base,rw)}>' if isinstance(dtype, PtrDType) else buffer_map[dtype]};" for name,(dtype,rw) in bufs]) + f"{name}:{f'array<{self.render_buf_dt(dtype.base,rw)}>' if isinstance(dtype, PtrDType) else buffer_map[dtype]};" for name,(dtype,rw) in bufs]) prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3," return prg + "@builtin(local_invocation_id) lindex: vec3) {\n" + "\n".join(kernel) + "\n}"