mirror of https://github.com/tinygrad/tinygrad.git
Use matching JS TypedArray for buffer dtype (#8080)
@@ -79,9 +79,10 @@ def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,in
   cprog += [f"void net({inputs}, {outputs}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
   return '\n'.join(cprog)
 
+def dtype_to_js_type(dtype: DType) -> str:
+  return f"{'Uint' if dtype in dtypes.uints else 'Int' if (dtype in dtypes.sints or dtype == dtypes.bool) else 'Float'}{8*dtype.itemsize}Array"
+
 def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names) -> Tuple[str,int,int]:
-  def dtype_to_js_type(dtype: DType) -> str:
-    return "Uint32Array" if dtype in dtypes.uints else "Int32Array" if (dtype in dtypes.sints or dtype == dtypes.bool) else "Float32Array"
   kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
   kernel_names = ', '.join([name for (name, _args, _global_size, _local_size) in statements])
   create_bind_group_layouts = ",".join([
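For reference, a small standalone sketch of the mapping that the new top-level dtype_to_js_type helper produces. The imports are assumptions about the tinygrad package layout (they are not part of this diff); the printed values follow directly from the f-string above.

# Standalone sketch of the dtype -> JS TypedArray mapping added above.
# Assumption: dtypes and DType are importable from the tinygrad package as shown.
from tinygrad import dtypes
from tinygrad.dtype import DType

def dtype_to_js_type(dtype: DType) -> str:
  return f"{'Uint' if dtype in dtypes.uints else 'Int' if (dtype in dtypes.sints or dtype == dtypes.bool) else 'Float'}{8*dtype.itemsize}Array"

print(dtype_to_js_type(dtypes.uint8))    # Uint8Array
print(dtype_to_js_type(dtypes.int16))    # Int16Array
print(dtype_to_js_type(dtypes.bool))     # Int8Array (bool is 1 byte and grouped with the signed ints)
print(dtype_to_js_type(dtypes.float32))  # Float32Array

The previous helper, nested inside export_model_webgpu and removed above, always returned the 32-bit variant regardless of the buffer's itemsize.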
@@ -99,7 +100,7 @@ def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names,
   input_writers = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new {input_buffer_types[i]}(gpuWriteBuffer{i}.getMappedRange()).set(" + f'_{inp_name});' + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" for i,inp_name in enumerate(input_names)])
   gpu_read_bufs = '\n '.join([f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});" for i,output_name in enumerate(output_names)])
   outbuf_copies = '\n '.join([f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);" for i,output_name in enumerate(output_names)])
-  output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new {output_buffer_types[i]}(gpuReadBuffer{i}.size/4);\n resultBuffer{i}.set(new {output_buffer_types[i]}(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
+  output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new {output_buffer_types[i]}(gpuReadBuffer{i}.size/{bufs[output_names[i]][1].itemsize});\n resultBuffer{i}.set(new {output_buffer_types[i]}(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
   output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
   return f"""
 {web_utils["getTensorBuffer"]}
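This hunk also fixes the element count of the generated JS result buffer: gpuReadBuffer{i}.size is a byte count, so dividing by a hard-coded 4 is only correct for 4-byte dtypes; the generator now divides by the output buffer's itemsize. A toy illustration of the arithmetic (the names and values below are hypothetical, not from the source):

# Hypothetical numbers: a 16-element int8 output occupies 16 bytes.
nbytes, itemsize = 16, 1
old_length = nbytes // 4          # 4  -> the old generated JS allocated a TypedArray that was too short
new_length = nbytes // itemsize   # 16 -> one element per output value
assert (old_length, new_length) == (4, 16)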
@@ -50,18 +50,20 @@ class TextModelExport(unittest.TestCase):
 class TextModelExportWebGPU(unittest.TestCase):
   def test_exported_input_output_dtypes(self):
     class MyModel:
-      def forward(self, *inputs): return tuple([inp+2 for inp in inputs])
+      def forward(self, *inputs): return tuple([(inp+2).cast(inp.dtype) for inp in inputs])
     model = MyModel()
     # [:-1] because "ulong" and "long" is not supported
     inputs = [Tensor.randn(2, dtype=dt) for dt in dtypes.uints[:-1] + dtypes.sints[:-1] + (dtypes.bool, dtypes.float)]
     prg, _, _, _ = export_model(model, "webgpu", *inputs)
     expected_buffer_types = ["Uint"]*len(dtypes.uints[:-1]) + ["Int"]*len(dtypes.sints[:-1]) + ["Int", "Float"]
     for i, expected_buffer_type in enumerate(expected_buffer_types):
+      dt = inputs[i].dtype
+      expected_arr_prefix = f"{expected_buffer_type}{dt.itemsize*8}"
       # test input buffers
-      self.assertIn(f"new {expected_buffer_type}32Array(gpuWriteBuffer{i}.getMappedRange()).set(_input{i});", prg)
+      self.assertIn(f"new {expected_arr_prefix}Array(gpuWriteBuffer{i}.getMappedRange()).set(_input{i});", prg)
       # test output buffers
-      self.assertIn(f"const resultBuffer{i} = new {expected_buffer_type}32Array(gpuReadBuffer{i}.size/4);", prg)
-      self.assertIn(f"resultBuffer{i}.set(new {expected_buffer_type}32Array(gpuReadBuffer{i}.getMappedRange()));", prg)
+      self.assertIn(f"const resultBuffer{i} = new {expected_arr_prefix}Array(gpuReadBuffer{i}.size/{dt.itemsize});", prg)
+      self.assertIn(f"resultBuffer{i}.set(new {expected_arr_prefix}Array(gpuReadBuffer{i}.getMappedRange()));", prg)
 
 if __name__ == '__main__':
   unittest.main()
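For completeness, a minimal sketch of exercising the export path the same way the test does: export a toy model to WebGPU and check that the generated JS picks a dtype-matched TypedArray. The imports mirror what the test file presumably uses (they are not shown in this diff), and actually running it requires the WEBGPU backend to be available.

# Minimal sketch, modeled on the test above. Imports are assumptions based on the
# test file, and a working WEBGPU backend is required to run the export.
from tinygrad import Tensor, dtypes
from extra.export_model import export_model

class MyModel:
  def forward(self, *inputs): return tuple([(inp+2).cast(inp.dtype) for inp in inputs])

inputs = [Tensor.randn(2, dtype=dtypes.uint8)]
prg, _, _, _ = export_model(MyModel(), "webgpu", *inputs)
assert "Uint8Array" in prg  # the exported JS now reads/writes this buffer as Uint8Array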