From ff9a89f7147be2c67ae0af9232ffc6a89977e3a8 Mon Sep 17 00:00:00 2001
From: Ahmed Harmouche
Date: Thu, 5 Dec 2024 10:38:05 +0100
Subject: [PATCH] Proper dtypes for input/output of exported WebGPU model
 (#8053)

* Respect input/output dtypes in exported WebGPU model

* Add some comments about skipped dtypes
---
 .github/workflows/test.yml          |  2 +-
 extra/export_model.py               |  8 ++++++--
 test/testextra/test_export_model.py | 17 +++++++++++++++++
 3 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 6a0dd78007..97440fa738 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -387,7 +387,7 @@ jobs:
         WEBGPU=1 WGPU_BACKEND_TYPE=Vulkan python3 -m pytest -n=auto test/test_assign.py test/test_arange.py test/test_const_folding.py test/test_dtype.py \
           test/test_dtype_alu.py test/test_conv.py test/test_conv_shapetracker.py test/test_nn.py test/test_ops.py test/test_optim.py \
           test/test_jit.py test/test_randomness.py test/test_symbolic_ops.py test/test_symbolic_jit.py test/test_uops_stats.py test/test_uops.py \
-          --durations=20
+          test/testextra/test_export_model.py --durations=20
       - name: Run process replay tests
         run: |
           export PR_TITLE=$(jq -r .pull_request.title "$GITHUB_EVENT_PATH")
diff --git a/extra/export_model.py b/extra/export_model.py
index c8a362f693..ac3c679333 100644
--- a/extra/export_model.py
+++ b/extra/export_model.py
@@ -80,6 +80,8 @@ def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,in
   return '\n'.join(cprog)
 
 def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names) -> Tuple[str,int,int]:
+  def dtype_to_js_type(dtype: DType) -> str:
+    return "Uint32Array" if dtype in dtypes.uints else "Int32Array" if (dtype in dtypes.ints or dtype == dtypes.bool) else "Float32Array"
   kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
   kernel_names = ', '.join([name for (name, _args, _global_size, _local_size) in statements])
   create_bind_group_layouts = ",".join([
@@ -92,10 +94,12 @@ def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names,
   kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], layouts[{i}], infinityBuf, [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
   _bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weight_names else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weight_names[_key]}']))") + ";" for name,(size,dtype,_key) in bufs.items()])
   gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:{input_name}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,input_name in enumerate(input_names)])
-  input_writers = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'_{inp_name});' + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" for i,inp_name in enumerate(input_names)])
+  input_buffer_types = [dtype_to_js_type(bufs[inp_name][1]) for inp_name in input_names]
+  output_buffer_types = [dtype_to_js_type(bufs[out_name][1]) for out_name in output_names]
+  input_writers = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new {input_buffer_types[i]}(gpuWriteBuffer{i}.getMappedRange()).set(" + f'_{inp_name});' + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" for i,inp_name in enumerate(input_names)])
   gpu_read_bufs = '\n '.join([f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});" for i,output_name in enumerate(output_names)])
   outbuf_copies = '\n '.join([f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);" for i,output_name in enumerate(output_names)])
-  output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new Float32Array(gpuReadBuffer{i}.size/4);\n resultBuffer{i}.set(new Float32Array(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
+  output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new {output_buffer_types[i]}(gpuReadBuffer{i}.size/4);\n resultBuffer{i}.set(new {output_buffer_types[i]}(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
   output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
   return f"""
 {web_utils["getTensorBuffer"]}
diff --git a/test/testextra/test_export_model.py b/test/testextra/test_export_model.py
index a41f16a24e..466f96b1ff 100644
--- a/test/testextra/test_export_model.py
+++ b/test/testextra/test_export_model.py
@@ -1,6 +1,7 @@
 import unittest
 from extra.export_model import export_model, EXPORT_SUPPORTED_DEVICE
 from tinygrad.tensor import Tensor, Device
+from tinygrad import dtypes
 import json
 
 class MockMultiInputModel:
@@ -45,6 +46,22 @@ class TextModelExport(unittest.TestCase):
     for i, exported_output in enumerate(prg["outputs"]):
       assert outputs[i].dtype.name == exported_output["dtype"], f"Model and exported output dtype don't match: mdl={outputs[i].dtype.name}, prg={exported_output['dtype']}" # noqa: E501
 
+@unittest.skipUnless(Device.DEFAULT == "WEBGPU", "Testing WebGPU specific model export behavior")
+class TextModelExportWebGPU(unittest.TestCase):
+  def test_exported_input_output_dtypes(self):
+    class MyModel:
+      def forward(self, *inputs): return tuple([inp+2 for inp in inputs])
+    model = MyModel()
+    # [:-1] because "ulong" and "long" are not supported
+    inputs = [Tensor.randn(2, dtype=dt) for dt in dtypes.uints[:-1] + dtypes.sints[:-1] + (dtypes.bool, dtypes.float)]
+    prg, _, _, _ = export_model(model, "webgpu", *inputs)
+    expected_buffer_types = ["Uint"]*len(dtypes.uints[:-1]) + ["Int"]*len(dtypes.sints[:-1]) + ["Int", "Float"]
+    for i, expected_buffer_type in enumerate(expected_buffer_types):
+      # test input buffers
+      self.assertIn(f"new {expected_buffer_type}32Array(gpuWriteBuffer{i}.getMappedRange()).set(_input{i});", prg)
+      # test output buffers
+      self.assertIn(f"const resultBuffer{i} = new {expected_buffer_type}32Array(gpuReadBuffer{i}.size/4);", prg)
+      self.assertIn(f"resultBuffer{i}.set(new {expected_buffer_type}32Array(gpuReadBuffer{i}.getMappedRange()));", prg)
 
 if __name__ == '__main__':
   unittest.main()
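
Illustration, not part of the patch: a minimal standalone sketch of the
dtype_to_js_type helper added above, pulled out of export_model_webgpu so it
can be run on its own. The exported kernels read and write 32-bit values
(which is why the output reader divides the buffer size by 4), so every
supported dtype maps to a 32-bit JS typed array, with bool riding along as
Int32Array.

    from tinygrad import dtypes
    from tinygrad.dtype import DType

    def dtype_to_js_type(dtype: DType) -> str:
      # Same expression as the patch: unsigned ints -> Uint32Array,
      # signed ints and bool -> Int32Array, everything else -> Float32Array.
      return "Uint32Array" if dtype in dtypes.uints else "Int32Array" if (dtype in dtypes.ints or dtype == dtypes.bool) else "Float32Array"

    print(dtype_to_js_type(dtypes.uint8))   # Uint32Array
    print(dtype_to_js_type(dtypes.int8))    # Int32Array
    print(dtype_to_js_type(dtypes.bool))    # Int32Array
    print(dtype_to_js_type(dtypes.float32)) # Float32Array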
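And a hedged usage sketch mirroring the new test: exporting a one-input model
on the WebGPU backend and checking that the generated JavaScript picks the
Int32Array writer for an int32 input. Running it needs WEBGPU as the default
device, exactly like the skipUnless guard in the test above.

    from extra.export_model import export_model
    from tinygrad import Tensor, dtypes

    class Model:
      def forward(self, x): return x + 2

    # prg is the generated JavaScript source; the other three return values
    # are unpacked and ignored here, the same way the test does it.
    prg, _, _, _ = export_model(Model(), "webgpu", Tensor.randn(2, dtype=dtypes.int32))
    assert "new Int32Array(gpuWriteBuffer0.getMappedRange()).set(_input0);" in prg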