simpler webgpu (#2956)

* simpler webgpu

* skip that test

George Hotz, 2024-01-01 10:28:59 -08:00 (committed by GitHub)
commit 063f465604, parent fea20d71b3
4 changed files with 15 additions and 18 deletions

@@ -405,7 +405,8 @@ class TestAutoCastType(unittest.TestCase):
assert (Tensor.rand(4, 4, dtype=dtypes.float64) + 2).dtype == dtypes.float64
def test_broadcast_bool(self):
assert (Tensor([0, 1], dtype=dtypes.bool) + True).dtype == dtypes.bool
if Device.DEFAULT != "WEBGPU":
assert (Tensor([0, 1], dtype=dtypes.bool) + True).dtype == dtypes.bool
assert (Tensor([0, 1], dtype=dtypes.int) + True).dtype == dtypes.int32
assert (Tensor([0, 1], dtype=dtypes.int8) + True).dtype == dtypes.int8
assert (Tensor([0, 1], dtype=dtypes.uint64) + True).dtype == dtypes.uint64
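
Why this assert is now guarded: with the LazyBuffer change in the next file, an elementwise op that would produce bool on WEBGPU is promoted to float32 (bool is not a valid storage-buffer type there), so the bool dtype assert no longer holds on that backend. A rough sketch of the resulting behavior, with import paths assumed from current tinygrad rather than taken from this diff:

from tinygrad import Tensor, Device, dtypes  # assumed import path, not shown in this diff

t = Tensor([0, 1], dtype=dtypes.bool) + True
# most backends keep the bool dtype; WEBGPU promotes the result to float32 in
# LazyBuffer.e, which is why the assert above is skipped on that device
expected = dtypes.float32 if Device.DEFAULT == "WEBGPU" else dtypes.bool
assert t.dtype == expected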

@@ -111,7 +111,8 @@ class LazyBuffer:
assert all_same(dts:=[x.dtype.scalar() for x in (srcs if op != TernaryOps.WHERE else srcs[1:])]), f"all dtypes must match {dts} on {op}"
if op == TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
out_dtype = srcs[-1].dtype if op not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else dtypes.bool
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
ret = create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
return ret.cast(dtypes.float32) if (out_dtype == dtypes.bool and self.device == "WEBGPU") else ret
# *** reduce ops ***
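
This moves the WEBGPU bool workaround into one place: e() still computes bool as the output dtype for comparisons, but the returned buffer is cast to float32 on WEBGPU because bool cannot back a storage buffer there. A minimal stand-alone sketch of the dtype the caller ends up with (hypothetical helper, not tinygrad code):

from tinygrad.dtype import dtypes  # assumed module path

def elementwise_result_dtype(is_compare: bool, src_dtype, device: str):
  # comparisons produce bool, everything else keeps the source dtype...
  out = dtypes.bool if is_compare else src_dtype
  # ...but WEBGPU cannot store bool in a buffer, so it is widened right away
  return dtypes.float32 if (out == dtypes.bool and device == "WEBGPU") else out

assert elementwise_result_dtype(True, dtypes.int32, "METAL") == dtypes.bool
assert elementwise_result_dtype(True, dtypes.int32, "WEBGPU") == dtypes.float32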

@@ -76,8 +76,6 @@ class CStyleLanguage(NamedTuple):
def render_if(self, cond: str): return f"if ({cond}) {{"
def render_conditional(self, cond: str, x:str, y:str) -> str: return f"({cond})?({x}):{y}"
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str:
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else "" # noqa: E501
buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
@@ -173,7 +171,8 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
elif uop == UOps.LOAD:
assert dtype is not None
val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL)
if len(vin) > 3: val = lang.render_conditional(r[vin[2]], val, r[vin[3]])
# NOTE: this relies on the load not happening if it's in the unselected branch
if len(vin) > 3: val = lang.code_for_op[TernaryOps.WHERE](r[vin[2]], val, r[vin[3]], dtype)
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'val')} = {val};")
elif uop == UOps.PHI:
kk(f"{r[vin[0]]} = {r[vin[1]]};")
@@ -317,6 +316,8 @@ class WGSLLanguage(CStyleLanguage):
elif math.isinf(x): return ("-" if x < 0 else "") + "inf(1.0)"
return f"({super().render_const(x, var_dtype)})"
def render_if(self, cond: str): return f"if (bool({cond})) {{"
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str:
local_size = local_size[::-1] if local_size else [1]
bind_it = iter(range(len(bufs)))
@@ -325,14 +326,7 @@ class WGSLLanguage(CStyleLanguage):
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}" # noqa: E501
return prg
def render_if(self, cond: str): return f"if (bool({cond})) {{"
def render_conditional(self, cond:str, x:str, y:str) -> str: return f"select({y}, {x}, bool({cond}))"
def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
if self.type_map[var_dtype]: return f"bitcast<{self.type_map[var_dtype]}>({x[0]})" if bitcast else f"{self.type_map[var_dtype]}({x[0]})"
raise NotImplementedError(f"no cast for {var_dtype}")
def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
return f"{buf_name}[{idx}] = {self.render_cast([var_name], buf_dtype) if var_dtype != buf_dtype else var_name};"
WGSLRenderer = functools.partial(uops_to_cstyle, WGSLLanguage())

@@ -672,6 +672,7 @@ class Tensor:
# ***** mlops (unary) *****
def neg(self): return mlops.Neg.apply(self)
def logical_not(self): return self.neg() if self.dtype == dtypes.bool else (1.0-self)
def contiguous(self): return mlops.Contiguous.apply(self)
def contiguous_backward(self): return mlops.ContiguousBackward.apply(self)
def log(self): return mlops.Log.apply(self.cast(least_upper_float(self.dtype)))
@@ -821,12 +822,12 @@ class Tensor:
def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x))
# in webgpu bool cannot be used as a storage buffer type
def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False)).cast(dtypes.float if Device.DEFAULT=="WEBGPU" else dtypes.bool)
def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True)).cast(dtypes.float if Device.DEFAULT=="WEBGPU" else dtypes.bool)
def __ge__(self, x) -> Tensor: return (self<x).neg() if Device.DEFAULT!="WEBGPU" else (1-(self<x)).cast(dtypes.float)
def __le__(self, x) -> Tensor: return (self>x).neg() if Device.DEFAULT!="WEBGPU" else (1-(self>x)).cast(dtypes.float)
def __eq__(self, x) -> Tensor: return mlops.Eq.apply(*self._broadcasted(x, True)).cast(dtypes.float if Device.DEFAULT=="WEBGPU" else dtypes.bool) # type: ignore[override]
def __ne__(self, x) -> Tensor: return (self==x).neg() if Device.DEFAULT!="WEBGPU" else (1-(self==x)).cast(dtypes.float) # type: ignore[override]
def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False))
def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True))
def __ge__(self, x) -> Tensor: return (self<x).logical_not()
def __le__(self, x) -> Tensor: return (self>x).logical_not()
def __eq__(self, x) -> Tensor: return mlops.Eq.apply(*self._broadcasted(x, True)) # type: ignore[override]
def __ne__(self, x) -> Tensor: return (self==x).logical_not() # type: ignore[override]
# ***** functional nn ops *****
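
With the promotion handled in LazyBuffer.e, the comparison operators lose their WEBGPU branches: they return bool and derive >=, <=, and != via logical_not. A small usage sketch, with the import path assumed rather than taken from this diff:

from tinygrad import Tensor, dtypes  # assumed import path

a, b = Tensor([1.0, 2.0, 3.0]), Tensor([2.0, 2.0, 2.0])
lt = a < b    # mlops.Less, no device-specific cast in the operator anymore
ge = a >= b   # (a < b).logical_not()
# lt.dtype is dtypes.bool at the Tensor level on most backends; on WEBGPU the
# underlying bool buffer is promoted to float32 by the e() change above
print(lt.dtype, ge.numpy())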