Use atomicLoad builtin when loading atomic type (#8084)

commit fad3eaa35e
parent 79966fade0
Author: Ahmed Harmouche
Committed by: GitHub
Date: 2024-12-06 15:33:11 +01:00

2 changed files with 15 additions and 9 deletions
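For context: in WGSL, a buffer element declared with an atomic type cannot be read by plain indexing; reads have to go through the atomicLoad builtin, which takes a pointer to the element. A minimal illustration of the two load forms involved, written as plain strings (the buffer and variable names are invented, not taken from this diff):

# Illustrative WGSL fragments held in Python strings; nothing here is emitted by this commit.
plain_load  = "var val0 = data0[gidx0];"               # valid when the element type is not atomic
atomic_load = "var val0 = atomicLoad(&data0[gidx0]);"  # required once the element is atomic<...>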


@@ -1,5 +1,5 @@
import os
-from extra.export_model import compile_net, jit_model
+from extra.export_model import compile_net, jit_model, dtype_to_js_type
from examples.stable_diffusion import StableDiffusion
from tinygrad.nn.state import get_state_dict, safe_save, safe_load_metadata, torch_load, load_state_dict
from tinygrad.tensor import Tensor
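The newly imported dtype_to_js_type helper is used further down to pick the JavaScript TypedArray view matching each buffer's dtype. A rough sketch of what such a mapping can look like (an approximation for illustration only; the real helper lives in extra/export_model and may differ):

from tinygrad.dtype import DType, dtypes

def dtype_to_js_type_sketch(dtype: DType) -> str:
  # e.g. dtypes.float32 -> "Float32Array", dtypes.int32 -> "Int32Array", dtypes.uchar -> "Uint8Array"
  kind = "Uint" if dtype in dtypes.uints else "Int" if dtype in dtypes.sints else "Float"
  return f"{kind}{8 * dtype.itemsize}Array"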
@@ -116,10 +116,14 @@ if __name__ == "__main__":
weights = {id(x.lazydata.base.realized): name for name, x in state.items()}
kernel_code = '\n\n'.join([f"const {key} = `{fixup_code(code, key)}`;" for key, code in functions.items()])
kernel_names = ', '.join([name for (name, _, _, _) in statements])
+input_names = [name for _,name in special_names.items() if "input" in name]
+output_names = [name for _,name in special_names.items() if "output" in name]
+input_buf_types = [dtype_to_js_type(bufs[inp_name][1]) for inp_name in input_names]
+output_buf_types = [dtype_to_js_type(bufs[out_name][1]) for out_name in output_names]
kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
-bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weights else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))") + ";" for name,(size,dtype,_key) in bufs.items()])
+exported_bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weights else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))") + ";" for name,(size,dtype,_key) in bufs.items()])
gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,(_,value) in enumerate(special_names.items()) if "output" not in value])
-input_writer = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'data{i});' + f"\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);" for i,(_,value) in enumerate(special_names.items()) if value != "output0"])
+input_writer = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new {input_buf_types[i]}(gpuWriteBuffer{i}.getMappedRange()).set(" + f'data{i});' + f"\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);" for i,_ in enumerate(input_names)])
return f"""\n var {step.name} = function() {{
{kernel_code}
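To make the new intermediate lists concrete, here is the shape of the data they are derived from, with invented names, ids and dtypes purely for illustration (special_names maps buffer ids to exported names, and bufs maps names to (size, dtype, key) tuples, as used elsewhere in this script):

from tinygrad.dtype import dtypes
from extra.export_model import dtype_to_js_type

special_names = {11: "input0", 12: "output0"}                                    # invented example
bufs = {"input0": (4096, dtypes.float32, 11), "output0": (1024, dtypes.float32, 12)}

input_names  = [name for _, name in special_names.items() if "input" in name]    # ["input0"]
output_names = [name for _, name in special_names.items() if "output" in name]   # ["output0"]
input_buf_types  = [dtype_to_js_type(bufs[n][1]) for n in input_names]           # e.g. ["Float32Array"]
output_buf_types = [dtype_to_js_type(bufs[n][1]) for n in output_names]          # e.g. ["Float32Array"]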
@@ -128,7 +132,7 @@ if __name__ == "__main__":
"setup": async (device, safetensor) => {{
const metadata = getTensorMetadata(safetensor[0]);
-{bufs}
+{exported_bufs}
{gpu_write_bufs}
const gpuReadBuffer = device.createBuffer({{ size: output0.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});
@@ -147,8 +151,8 @@ if __name__ == "__main__":
device.queue.submit([gpuCommands]);
await gpuReadBuffer.mapAsync(GPUMapMode.READ);
-const resultBuffer = new Float32Array(gpuReadBuffer.size/4);
-resultBuffer.set(new Float32Array(gpuReadBuffer.getMappedRange()));
+const resultBuffer = new {output_buf_types[0]}(gpuReadBuffer.size/{bufs[output_names[0]][1].itemsize});
+resultBuffer.set(new {output_buf_types[0]}(gpuReadBuffer.getMappedRange()));
gpuReadBuffer.unmap();
return resultBuffer;
}}
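The reworked readback above no longer hard-codes Float32Array and size/4: the element count of the result view is the mapped byte size divided by the output dtype's itemsize, and the view type comes from dtype_to_js_type. A small worked example with an invented non-float32 output:

from tinygrad.dtype import dtypes
from extra.export_model import dtype_to_js_type

read_buffer_bytes = 1024                              # hypothetical gpuReadBuffer.size
out_dtype = dtypes.uchar                              # hypothetical 1-byte output dtype
elements = read_buffer_bytes // out_dtype.itemsize    # 1024 elements; the old size/4 would give 256
js_view = dtype_to_js_type(out_dtype)                 # e.g. "Uint8Array" instead of "Float32Array"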


@@ -67,8 +67,9 @@ class WGSLRenderer(CStyleLanguage):
(UPat(Ops.BITCAST, dtype=(dtypes.char, dtypes.uchar), name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]}&0xFF)"),
(UPat(Ops.BITCAST, dtype=(dtypes.short, dtypes.ushort), name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]}&0xFFFF)"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]})"),
(UPat(Ops.LOAD, src=(UPat.var("b"), UPat.var('v'), UPat.var("g"))), lambda ctx,b,v,g: f"select({ctx[v]}, {ctx[b]}, {ctx[g]})"),
(UPat(Ops.LOAD, src=(UPat.var('b'),), allow_any_len=True), lambda ctx, b: ctx[b]),
(UPat(Ops.LOAD, src=(UPat.var("b"), UPat.var('v'), UPat.var("g"))), \
lambda ctx,b,v,g: f"select({ctx[v]}, {ctx.render_load(ctx[b], b.src[0].dtype)}, {ctx[g]})"),
(UPat(Ops.LOAD, src=(UPat.var('b'),), allow_any_len=True), lambda ctx, b: ctx.render_load(ctx[b], b.src[0].dtype)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
lambda ctx,buf,idx: f"{ctx[buf]}[{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]}]"),
(UPat(Ops.STORE, src=(UPat.var('b'), UPat.var("v"))),lambda ctx,b,v:\
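The effect of routing both LOAD patterns through render_load can be sketched with plain string manipulation; the expressions below are invented examples, not output captured from the renderer (is_atomic stands in for the dt.itemsize < 4 check used by render_load in the next hunk):

def sketch_load(buf_expr: str, is_atomic: bool) -> str:
  # Mirrors render_load: an atomic element must be read through atomicLoad.
  return f"atomicLoad(&{buf_expr})" if is_atomic else buf_expr

def sketch_gated_load(buf_expr: str, alt_expr: str, gate_expr: str, is_atomic: bool) -> str:
  # Mirrors the gated LOAD pattern: select between the alternative value and the loaded one.
  return f"select({alt_expr}, {sketch_load(buf_expr, is_atomic)}, {gate_expr})"

sketch_gated_load("data0[gidx0]", "0.0f", "gidx0<10", is_atomic=False)
# -> 'select(0.0f, data0[gidx0], gidx0<10)'
sketch_gated_load("data1[gidx0]", "0", "gidx0<10", is_atomic=True)
# -> 'select(0, atomicLoad(&data1[gidx0]), gidx0<10)'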
@@ -81,7 +82,8 @@ class WGSLRenderer(CStyleLanguage):
def render_cast(self, dt:DType, val: str) -> str: return f"{self.type_map[dt]}({val})"
def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
-def render_buf_dt(self, dt:DType, rw=True) -> str: return f"{f'atomic<{buffer_map[dt]}>' if rw and dt.itemsize < 4 else buffer_map[dt.base]}"
+def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if dt.itemsize < 4 else x
+def render_buf_dt(self, dt:DType, rw=True) -> str: return f"{f'atomic<{buffer_map[dt]}>' if dt.itemsize < 4 else buffer_map[dt.base]}"
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
local_size = [num for _, num in sorted([u.arg for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == 'l'], key=lambda x: x[0])]
if not local_size: local_size = [1]
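A standalone sketch of how the declaration side (render_buf_dt) and the load side (render_load) now agree for dtypes narrower than 4 bytes; buffer_map is approximated here for illustration, and the real mapping defined elsewhere in this renderer may differ:

from tinygrad.dtype import DType, dtypes

buffer_map_sketch = {dtypes.char: "i32", dtypes.uchar: "u32", dtypes.int32: "i32", dtypes.float32: "f32"}

def render_buf_dt_sketch(dt: DType) -> str:
  # Declaration side: dtypes narrower than 4 bytes get an atomic<...> element type.
  return f"atomic<{buffer_map_sketch[dt]}>" if dt.itemsize < 4 else buffer_map_sketch[dt]

def render_load_sketch(x: str, dt: DType) -> str:
  # Load side: anything declared atomic must be read through atomicLoad.
  return f"atomicLoad(&{x})" if dt.itemsize < 4 else x

render_buf_dt_sketch(dtypes.uchar)                  # -> 'atomic<u32>'
render_load_sketch("data0[gidx0]", dtypes.uchar)    # -> 'atomicLoad(&data0[gidx0])'
render_load_sketch("data1[gidx0]", dtypes.float32)  # -> 'data1[gidx0]'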