mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Proper dtypes for input/output of exported WebGPU model (#8053)
* Respect input/output dtypes in exported WebGPU model * Add some comments about skipped dtypes
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import unittest
|
||||
from extra.export_model import export_model, EXPORT_SUPPORTED_DEVICE
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from tinygrad import dtypes
|
||||
import json
|
||||
|
||||
class MockMultiInputModel:
|
||||
@@ -45,6 +46,22 @@ class TextModelExport(unittest.TestCase):
|
||||
for i, exported_output in enumerate(prg["outputs"]):
|
||||
assert outputs[i].dtype.name == exported_output["dtype"], f"Model and exported output dtype don't match: mdl={outputs[i].dtype.name}, prg={exported_output['dtype']}" # noqa: E501
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "WEBGPU", "Testing WebGPU specific model export behavior")
|
||||
class TextModelExportWebGPU(unittest.TestCase):
|
||||
def test_exported_input_output_dtypes(self):
|
||||
class MyModel:
|
||||
def forward(self, *inputs): return tuple([inp+2 for inp in inputs])
|
||||
model = MyModel()
|
||||
# [:-1] because "ulong" and "long" is not supported
|
||||
inputs = [Tensor.randn(2, dtype=dt) for dt in dtypes.uints[:-1] + dtypes.sints[:-1] + (dtypes.bool, dtypes.float)]
|
||||
prg, _, _, _ = export_model(model, "webgpu", *inputs)
|
||||
expected_buffer_types = ["Uint"]*len(dtypes.uints[:-1]) + ["Int"]*len(dtypes.sints[:-1]) + ["Int", "Float"]
|
||||
for i, expected_buffer_type in enumerate(expected_buffer_types):
|
||||
# test input buffers
|
||||
self.assertIn(f"new {expected_buffer_type}32Array(gpuWriteBuffer{i}.getMappedRange()).set(_input{i});", prg)
|
||||
# test output buffers
|
||||
self.assertIn(f"const resultBuffer{i} = new {expected_buffer_type}32Array(gpuReadBuffer{i}.size/4);", prg)
|
||||
self.assertIn(f"resultBuffer{i}.set(new {expected_buffer_type}32Array(gpuReadBuffer{i}.getMappedRange()));", prg)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user