diff --git a/test/test_dtype.py b/test/test_dtype.py
index 1c459ff9ca..070fe80f97 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -405,7 +405,8 @@ class TestAutoCastType(unittest.TestCase):
     assert (Tensor.rand(4, 4, dtype=dtypes.float64) + 2).dtype == dtypes.float64
 
   def test_broadcast_bool(self):
-    assert (Tensor([0, 1], dtype=dtypes.bool) + True).dtype == dtypes.bool
+    if Device.DEFAULT != "WEBGPU":
+      assert (Tensor([0, 1], dtype=dtypes.bool) + True).dtype == dtypes.bool
     assert (Tensor([0, 1], dtype=dtypes.int) + True).dtype == dtypes.int32
     assert (Tensor([0, 1], dtype=dtypes.int8) + True).dtype == dtypes.int8
     assert (Tensor([0, 1], dtype=dtypes.uint64) + True).dtype == dtypes.uint64
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 4d1ba527be..78abc9a630 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -111,7 +111,8 @@ class LazyBuffer:
     assert all_same(dts:=[x.dtype.scalar() for x in (srcs if op != TernaryOps.WHERE else srcs[1:])]), f"all dtypes must match {dts} on {op}"
     if op == TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
     out_dtype = srcs[-1].dtype if op not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else dtypes.bool
-    return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
+    ret = create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
+    return ret.cast(dtypes.float32) if (out_dtype == dtypes.bool and self.device == "WEBGPU") else ret
 
   # *** reduce ops ***
 
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index dbafe88e7e..1d9d5bf6c1 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -76,8 +76,6 @@ class CStyleLanguage(NamedTuple):
   def render_if(self, cond: str):
     return f"if ({cond}) {{"
 
-  def render_conditional(self, cond: str, x:str, y:str) -> str: return f"({cond})?({x}):{y}"
-
   def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str:
     tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else ""  # noqa: E501
     buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
@@ -173,7 +171,8 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
     elif uop == UOps.LOAD:
       assert dtype is not None
       val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL)
-      if len(vin) > 3: val = lang.render_conditional(r[vin[2]], val, r[vin[3]])
+      # NOTE: this relies on the load not happening if it's in the unselected branch
+      if len(vin) > 3: val = lang.code_for_op[TernaryOps.WHERE](r[vin[2]], val, r[vin[3]], dtype)
       kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'val')} = {val};")
     elif uop == UOps.PHI:
       kk(f"{r[vin[0]]} = {r[vin[1]]};")
@@ -317,6 +316,8 @@ class WGSLLanguage(CStyleLanguage):
     elif math.isinf(x): return ("-" if x < 0 else "") + "inf(1.0)"
     return f"({super().render_const(x, var_dtype)})"
 
+  def render_if(self, cond: str): return f"if (bool({cond})) {{"
+
   def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str:
     local_size = local_size[::-1] if local_size else [1]
     bind_it = iter(range(len(bufs)))
@@ -325,14 +326,7 @@ class WGSLLanguage(CStyleLanguage):
     prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}"  # noqa: E501
     return prg
 
-  def render_if(self, cond: str): return f"if (bool({cond})) {{"
-
-  def render_conditional(self, cond:str, x:str, y:str) -> str: return f"select({y}, {x}, bool({cond}))"
-
   def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
     if self.type_map[var_dtype]: return f"bitcast<{self.type_map[var_dtype]}>({x[0]})" if bitcast else f"{self.type_map[var_dtype]}({x[0]})"
     raise NotImplementedError(f"no cast for {var_dtype}")
-
-  def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
-    return f"{buf_name}[{idx}] = {self.render_cast([var_name], buf_dtype) if var_dtype != buf_dtype else var_name};"
 WGSLRenderer = functools.partial(uops_to_cstyle, WGSLLanguage())
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index dd54c2e4d0..de05215666 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -672,6 +672,7 @@ class Tensor:
   # ***** mlops (unary) *****
 
   def neg(self): return mlops.Neg.apply(self)
+  def logical_not(self): return self.neg() if self.dtype == dtypes.bool else (1.0-self)
   def contiguous(self): return mlops.Contiguous.apply(self)
   def contiguous_backward(self): return mlops.ContiguousBackward.apply(self)
   def log(self): return mlops.Log.apply(self.cast(least_upper_float(self.dtype)))
@@ -821,12 +822,12 @@ class Tensor:
   def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x))
 
   # in webgpu bool cannot be used as a storage buffer type
-  def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False)).cast(dtypes.float if Device.DEFAULT=="WEBGPU" else dtypes.bool)
-  def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True)).cast(dtypes.float if Device.DEFAULT=="WEBGPU" else dtypes.bool)
-  def __ge__(self, x) -> Tensor: return (self<x).neg() if Device.DEFAULT!="WEBGPU" else (1-(self<x)).cast(dtypes.float)
-  def __le__(self, x) -> Tensor: return (self>x).neg() if Device.DEFAULT!="WEBGPU" else (1-(self>x)).cast(dtypes.float)
-  def __eq__(self, x) -> Tensor: return mlops.Eq.apply(*self._broadcasted(x, True)).cast(dtypes.float if Device.DEFAULT=="WEBGPU" else dtypes.bool)  # type: ignore[override]
-  def __ne__(self, x) -> Tensor: return (self==x).neg() if Device.DEFAULT!="WEBGPU" else (1-(self==x)).cast(dtypes.float)  # type: ignore[override]
+  def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False))
+  def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True))
+  def __ge__(self, x) -> Tensor: return (self<x).logical_not()
+  def __le__(self, x) -> Tensor: return (self>x).logical_not()
+  def __eq__(self, x) -> Tensor: return mlops.Eq.apply(*self._broadcasted(x, True))  # type: ignore[override]
+  def __ne__(self, x) -> Tensor: return (self==x).logical_not()  # type: ignore[override]
 
   # ***** functional nn ops *****
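
Illustrative sketch, not part of the patch: the Tensor-level dtype behavior this change aims for on a non-WEBGPU backend. It assumes dtypes is importable from tinygrad.helpers at this revision; on WEBGPU the bool result is instead promoted to float32 inside LazyBuffer.e, since WGSL cannot use bool as a storage buffer type.

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes  # assumption: dtypes lives in helpers at this revision

a, b = Tensor([1.0, 2.0, 3.0]), Tensor([2.0, 2.0, 2.0])
assert (a < b).dtype == dtypes.bool    # mlops.Less result is no longer cast per-device
assert (a >= b).dtype == dtypes.bool   # __ge__ is (a < b).logical_not(), which keeps bool
assert (a != b).dtype == dtypes.bool   # __ne__ is (a == b).logical_not()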