Fix cast error in render_load in wgsl (#1956)

* Fix cast error in wgsl

* User render_cast intead of introducing new method

* Make it shorter

* Add back webgpu tests: efficientnet and dtypes
This commit is contained in:
Ahmed Harmouche
2023-10-04 11:29:14 +02:00
committed by GitHub
parent 6a79d4044a
commit fb4d830a2a
4 changed files with 16 additions and 6 deletions

View File

@@ -152,6 +152,7 @@ class TestInt32Dtype(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support int64")
def test_int32_upcast_int64(self): _test_ops(a_dtype=dtypes.int32, b_dtype=dtypes.int64, target_dtype=dtypes.int64)
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "host-shareablity is a requirement for storage buffers, but 'bool' type is not host-shareable")
class TestBoolDtype(unittest.TestCase):
def test_casts_from_bool(self): _test_casts_from([0,1,1,0], source_dtype=dtypes.bool, target_dtypes=[dtypes.float32, dtypes.int32])
def test_casts_to_bool(self): _test_casts_to([0,1,1,0], source_dtypes=[dtypes.float32, dtypes.int32], target_dtype=dtypes.bool)