mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-23 22:08:08 -05:00)
@@ -405,7 +405,8 @@ class TestAutoCastType(unittest.TestCase):
     assert (Tensor.rand(4, 4, dtype=dtypes.float64) + 2).dtype == dtypes.float64
 
   def test_broadcast_bool(self):
-    assert (Tensor([0, 1], dtype=dtypes.bool) + True).dtype == dtypes.bool
+    if Device.DEFAULT != "WEBGPU":
+      assert (Tensor([0, 1], dtype=dtypes.bool) + True).dtype == dtypes.bool
     assert (Tensor([0, 1], dtype=dtypes.int) + True).dtype == dtypes.int32
     assert (Tensor([0, 1], dtype=dtypes.int8) + True).dtype == dtypes.int8
     assert (Tensor([0, 1], dtype=dtypes.uint64) + True).dtype == dtypes.uint64
@@ -111,7 +111,8 @@ class LazyBuffer:
     assert all_same(dts:=[x.dtype.scalar() for x in (srcs if op != TernaryOps.WHERE else srcs[1:])]), f"all dtypes must match {dts} on {op}"
     if op == TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
     out_dtype = srcs[-1].dtype if op not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else dtypes.bool
-    return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
+    ret = create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
+    return ret.cast(dtypes.float32) if (out_dtype == dtypes.bool and self.device == "WEBGPU") else ret
 
   # *** reduce ops ***
 
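Restated outside the diff, the rule this hunk introduces is roughly the following. This is an illustrative sketch only (the function name and plain-string dtype labels are not from the tree): bool results are still computed, but a bool output buffer becomes float32 on WEBGPU, matching the comment below that WGSL does not allow bool storage buffers.

def output_dtype(op_dtype: str, device: str) -> str:
  # bool results are computed as usual, but stored as float32 on WEBGPU,
  # since bool cannot be used as a storage buffer type in WGSL
  return "float32" if op_dtype == "bool" and device == "WEBGPU" else op_dtype

assert output_dtype("bool", "WEBGPU") == "float32"
assert output_dtype("bool", "METAL") == "bool"
assert output_dtype("int32", "WEBGPU") == "int32"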
@@ -76,8 +76,6 @@ class CStyleLanguage(NamedTuple):
 
   def render_if(self, cond: str): return f"if ({cond}) {{"
 
-  def render_conditional(self, cond: str, x:str, y:str) -> str: return f"({cond})?({x}):{y}"
-
   def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str:
     tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else "" # noqa: E501
     buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
@@ -173,7 +171,8 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
     elif uop == UOps.LOAD:
       assert dtype is not None
       val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL)
-      if len(vin) > 3: val = lang.render_conditional(r[vin[2]], val, r[vin[3]])
+      # NOTE: this relies on the load not happening if it's in the unselected branch
+      if len(vin) > 3: val = lang.code_for_op[TernaryOps.WHERE](r[vin[2]], val, r[vin[3]], dtype)
       kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'val')} = {val};")
     elif uop == UOps.PHI:
       kk(f"{r[vin[0]]} = {r[vin[1]]};")
@@ -317,6 +316,8 @@ class WGSLLanguage(CStyleLanguage):
     elif math.isinf(x): return ("-" if x < 0 else "") + "inf(1.0)"
     return f"({super().render_const(x, var_dtype)})"
 
+  def render_if(self, cond: str): return f"if (bool({cond})) {{"
+
   def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str:
     local_size = local_size[::-1] if local_size else [1]
     bind_it = iter(range(len(bufs)))
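The bool() wrapper matters because a WGSL if statement requires a genuinely bool-typed expression, and with the workaround above a rendered condition may well be f32 or i32. A small sketch of the string this helper emits; the import path is an assumption about where WGSLLanguage lives in the tree this diff targets:

# sketch; import path assumed, not taken from the diff
from tinygrad.renderer.cstyle import WGSLLanguage
print(WGSLLanguage().render_if("lidx0<4"))  # -> "if (bool(lidx0<4)) {"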
@@ -325,14 +326,7 @@ class WGSLLanguage(CStyleLanguage):
     prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}" # noqa: E501
     return prg
 
-  def render_if(self, cond: str): return f"if (bool({cond})) {{"
-
-  def render_conditional(self, cond:str, x:str, y:str) -> str: return f"select({y}, {x}, bool({cond}))"
-
   def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
     if self.type_map[var_dtype]: return f"bitcast<{self.type_map[var_dtype]}>({x[0]})" if bitcast else f"{self.type_map[var_dtype]}({x[0]})"
     raise NotImplementedError(f"no cast for {var_dtype}")
-
-  def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
-    return f"{buf_name}[{idx}] = {self.render_cast([var_name], buf_dtype) if var_dtype != buf_dtype else var_name};"
 WGSLRenderer = functools.partial(uops_to_cstyle, WGSLLanguage())
@@ -672,6 +672,7 @@ class Tensor:
   # ***** mlops (unary) *****
 
   def neg(self): return mlops.Neg.apply(self)
+  def logical_not(self): return self.neg() if self.dtype == dtypes.bool else (1.0-self)
   def contiguous(self): return mlops.Contiguous.apply(self)
   def contiguous_backward(self): return mlops.ContiguousBackward.apply(self)
   def log(self): return mlops.Log.apply(self.cast(least_upper_float(self.dtype)))
@@ -821,12 +822,12 @@ class Tensor:
   def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x))
 
   # in webgpu bool cannot be used as a storage buffer type
-  def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False)).cast(dtypes.float if Device.DEFAULT=="WEBGPU" else dtypes.bool)
-  def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True)).cast(dtypes.float if Device.DEFAULT=="WEBGPU" else dtypes.bool)
-  def __ge__(self, x) -> Tensor: return (self<x).neg() if Device.DEFAULT!="WEBGPU" else (1-(self<x)).cast(dtypes.float)
-  def __le__(self, x) -> Tensor: return (self>x).neg() if Device.DEFAULT!="WEBGPU" else (1-(self>x)).cast(dtypes.float)
-  def __eq__(self, x) -> Tensor: return mlops.Eq.apply(*self._broadcasted(x, True)).cast(dtypes.float if Device.DEFAULT=="WEBGPU" else dtypes.bool) # type: ignore[override]
-  def __ne__(self, x) -> Tensor: return (self==x).neg() if Device.DEFAULT!="WEBGPU" else (1-(self==x)).cast(dtypes.float) # type: ignore[override]
+  def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False))
+  def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True))
+  def __ge__(self, x) -> Tensor: return (self<x).logical_not()
+  def __le__(self, x) -> Tensor: return (self>x).logical_not()
+  def __eq__(self, x) -> Tensor: return mlops.Eq.apply(*self._broadcasted(x, True)) # type: ignore[override]
+  def __ne__(self, x) -> Tensor: return (self==x).logical_not() # type: ignore[override]
 
   # ***** functional nn ops *****
 
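As a quick illustration of the comparison API after this change, a sketch only: it assumes the import layout of this tree (dtypes still in tinygrad.helpers), and note that on WEBGPU the LazyBuffer cast above still turns these bool results into float32.

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

a = Tensor([1.0, 2.0, 3.0])
mask = a < 2.0             # mlops.Less, no per-device cast at the Tensor level anymore
print(mask.dtype)          # dtypes.bool on most backends (float32 on WEBGPU)
print((a != 2.0).numpy())  # __ne__ is now (a == 2.0).logical_not()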