Use matching JS TypedArray for buffer dtype (#8080)

This commit is contained in:
Ahmed Harmouche
2024-12-06 14:52:23 +01:00
committed by GitHub
parent a684d72e55
commit ba35c4138b
2 changed files with 10 additions and 7 deletions

View File

@@ -50,18 +50,20 @@ class TextModelExport(unittest.TestCase):
class TextModelExportWebGPU(unittest.TestCase):
def test_exported_input_output_dtypes(self):
class MyModel:
def forward(self, *inputs): return tuple([inp+2 for inp in inputs])
def forward(self, *inputs): return tuple([(inp+2).cast(inp.dtype) for inp in inputs])
model = MyModel()
# [:-1] because "ulong" and "long" is not supported
inputs = [Tensor.randn(2, dtype=dt) for dt in dtypes.uints[:-1] + dtypes.sints[:-1] + (dtypes.bool, dtypes.float)]
prg, _, _, _ = export_model(model, "webgpu", *inputs)
expected_buffer_types = ["Uint"]*len(dtypes.uints[:-1]) + ["Int"]*len(dtypes.sints[:-1]) + ["Int", "Float"]
for i, expected_buffer_type in enumerate(expected_buffer_types):
dt = inputs[i].dtype
expected_arr_prefix = f"{expected_buffer_type}{dt.itemsize*8}"
# test input buffers
self.assertIn(f"new {expected_buffer_type}32Array(gpuWriteBuffer{i}.getMappedRange()).set(_input{i});", prg)
self.assertIn(f"new {expected_arr_prefix}Array(gpuWriteBuffer{i}.getMappedRange()).set(_input{i});", prg)
# test output buffers
self.assertIn(f"const resultBuffer{i} = new {expected_buffer_type}32Array(gpuReadBuffer{i}.size/4);", prg)
self.assertIn(f"resultBuffer{i}.set(new {expected_buffer_type}32Array(gpuReadBuffer{i}.getMappedRange()));", prg)
self.assertIn(f"const resultBuffer{i} = new {expected_arr_prefix}Array(gpuReadBuffer{i}.size/{dt.itemsize});", prg)
self.assertIn(f"resultBuffer{i}.set(new {expected_arr_prefix}Array(gpuReadBuffer{i}.getMappedRange()));", prg)
if __name__ == '__main__':
unittest.main()