diff --git a/tinygrad/renderer/wgsl.py b/tinygrad/renderer/wgsl.py
index 1d32dadf1b..2a0903f89b 100644
--- a/tinygrad/renderer/wgsl.py
+++ b/tinygrad/renderer/wgsl.py
@@ -42,10 +42,6 @@ wgsl_matcher = PatternMatcher([
    lambda a,g,c1,c2: g.where(c1, a) if math.isnan(c1.arg) and c2.arg == 1.0 else None),
 ]) + extra_pm
 
-type_map = { dtypes.float: "f32", dtypes.uchar: "u32", dtypes.ushort: "u32", dtypes.short: "i32",
-             dtypes.char: "i32", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "bool" }
-buffer_map = { **type_map, dtypes.bool: "i32" }
-
 class WGSLRenderer(CStyleLanguage):
   device = "WEBGPU"
   global_max = (65535, 65535, 65535)
@@ -56,7 +52,9 @@ class WGSLRenderer(CStyleLanguage):
   barrier = "workgroupBarrier();"
   code_for_op = {**CStyleLanguage.code_for_op, Ops.WHERE: lambda a,b,c,dtype: f"select({c},{b},{a})"}
   nan = "nan()"
-  type_map = type_map
+  type_map = { dtypes.float: "f32", dtypes.uchar: "u32", dtypes.ushort: "u32", dtypes.short: "i32",
+               dtypes.char: "i32", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "bool" }
+  buf_map = { **type_map, dtypes.bool: "i32" }
 
   string_rewrite = PatternMatcher([
     (UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "true" if x.arg else "false"),
@@ -64,15 +62,11 @@ class WGSLRenderer(CStyleLanguage):
       if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"),
     (UPat(Ops.CONST, dtype=dtypes.int32, name="x"), lambda ctx,x: f"bitcast<i32>({x.arg}u)" if x.arg >= 0x80000000 else f"{x.arg}"),
     (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.render_buf_dt(x.dtype.base)}, {x.arg[1]}>;"),
-    (UPat(Ops.BITCAST, dtype=(dtypes.char, dtypes.uchar), name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]}&0xFF)"),
-    (UPat(Ops.BITCAST, dtype=(dtypes.short, dtypes.ushort), name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]}&0xFFFF)"),
-    (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]})"),
-    (UPat(Ops.LOAD, src=(UPat.var("b"), UPat.var('v'), UPat.var("g"))), \
-      lambda ctx,b,v,g: f"select({ctx[v]}, {ctx.render_load(ctx[b], b.src[0].dtype)}, {ctx[g]})"),
-    (UPat(Ops.LOAD, src=(UPat.var('b'),), allow_any_len=True), lambda ctx, b: ctx.render_load(ctx[b], b.src[0].dtype)),
-    (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
-      lambda ctx,buf,idx: f"{ctx[buf]}[{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]}]"),
-    (UPat(Ops.STORE, src=(UPat.var('b'), UPat.var("v")), allow_any_len=True),lambda ctx,b,v:\
+    (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}{['&0xFF','&0xFFFF','',''][x.dtype.itemsize-1]})"),
+    (UPat.load(UPat.var("b"),UPat.var("v"),UPat.var("g")),lambda ctx,b,v,g:f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[g]})"),
+    (UPat.load(UPat.var("b"), allow_any_len=True), lambda ctx, b: ctx.render_load(ctx[b], b.src[0].dtype)),
+    (UPat.index(UPat.var("b"), UPat.var("idx")), lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]}]"),
+    (UPat.store(UPat.var('b'), UPat.var("v"), allow_any_len=True),lambda ctx,b,v:\
     # (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
       f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if b.src[0].dtype.itemsize < 4 \
       else f"{ctx[b]} = {ctx[v]};"),
@@ -83,7 +77,7 @@ class WGSLRenderer(CStyleLanguage):
   def render_cast(self, dt:DType, val: str) -> str: return f"{self.type_map[dt]}({val})"
   def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
   def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if dt.itemsize < 4 else x
-  def render_buf_dt(self, dt:DType, rw=True) -> str: return f"{f'atomic<{buffer_map[dt]}>' if dt.itemsize < 4 else buffer_map[dt.base]}"
+  def render_buf_dt(self, dt:DType, rw=True) -> str: return f"{f'atomic<{self.buf_map[dt]}>' if dt.itemsize < 4 else self.buf_map[dt.base]}"
   def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
     local_size = [num for _, num in sorted([u.arg for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == 'l'], key=lambda x: x[0])]
     if not local_size: local_size = [1]
@@ -92,10 +86,9 @@ class WGSLRenderer(CStyleLanguage):
     kernel[:] = [line for line in kernel if "var<workgroup>" not in line]
     prg = "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
     # trick to obfuscate compiler so that nan is detected properly
-    prg += "fn is_nan(v:f32) -> bool { return min(v, 1.0) == 1.0 && max(v, -1.0) == -1.0; }\n"
-    prg += "@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
+    prg += "fn is_nan(v:f32) -> bool { return min(v, 1.0) == 1.0 && max(v, -1.0) == -1.0; }\n@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
     prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" + f"{'var<storage,read_write>' if isinstance(dtype, PtrDType) else 'var<uniform>'}" +
-      f"{name}:{f'array<{self.render_buf_dt(dtype.base,rw)}>' if isinstance(dtype, PtrDType) else buffer_map[dtype]};" for name,(dtype,rw) in bufs])
+      f"{name}:{f'array<{self.render_buf_dt(dtype.base,rw)}>' if isinstance(dtype,PtrDType) else self.buf_map[dtype]};" for name,(dtype,rw) in bufs])
     prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
     return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"
 
diff --git a/tinygrad/runtime/ops_webgpu.py b/tinygrad/runtime/ops_webgpu.py
index 0e61c2553a..fe67491cc6 100644
--- a/tinygrad/runtime/ops_webgpu.py
+++ b/tinygrad/runtime/ops_webgpu.py
@@ -7,8 +7,7 @@ import struct
 
 def create_uniform(wgpu_device, val) -> wgpu.GPUBuffer:
   buf = wgpu_device.create_buffer(size=4, usage=wgpu.BufferUsage.UNIFORM | wgpu.BufferUsage.COPY_DST)
-  if isinstance(val, int): wgpu_device.queue.write_buffer(buf, 0, val.to_bytes(4, "little"))
-  else: wgpu_device.queue.write_buffer(buf, 0, struct.pack('
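
Side note on the consolidated BITCAST rewrite in the wgsl.py hunk above: the single pattern replaces the three per-width patterns by indexing a mask table with dtype.itemsize, so 1-byte dtypes get "&0xFF", 2-byte dtypes get "&0xFFFF", and 4-byte dtypes pass through unmasked. A minimal standalone sketch of that lookup follows; pick_mask is a hypothetical helper for illustration only and is not part of this PR.

# Illustration only: mask selection mirroring the new BITCAST pattern's
# ['&0xFF','&0xFFFF','',''][x.dtype.itemsize-1] lookup.
def pick_mask(itemsize: int) -> str:
  # itemsize 1 -> "&0xFF", 2 -> "&0xFFFF", 4 -> "" (full 32-bit value, no masking)
  return ['&0xFF', '&0xFFFF', '', ''][itemsize - 1]

assert pick_mask(1) == '&0xFF'    # char/uchar packed into a 32-bit word
assert pick_mask(2) == '&0xFFFF'  # short/ushort
assert pick_mask(4) == ''         # int32/uint32/float32 use the whole word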