diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index 04f87e3bf6..724045788a 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -364,7 +364,7 @@ def NegativeLogLikelihoodLoss(x: Tensor, target: Tensor, weight=None, ignore_ind cond = target == ignore_index weight = cond.where(0, weight) if weight is not None else cond.where(Tensor.zeros(*target.shape), 1) mask = target[:, None, :] == Tensor.arange(C).reshape([1, C] + [1]*(len(x.shape) -2)) - loss = (-mask * x).sum(axis=1) * (1 if weight is None else weight) + loss = -(mask * x).sum(axis=1) * (1 if weight is None else weight) if reduction == "mean": return loss.mean() if weight is None else loss.sum() / weight.sum() if reduction == "sum": return loss.sum() return loss.reshape(t_shape) if len(i_shape) != 3 else loss @@ -404,7 +404,7 @@ def _round(x:Tensor, n:float, equidistant_case = "round_down") -> Tensor: if equidistant_case == "round_down": return (x > b).where(b+1-n, b-n) if equidistant_case == "round_up": return (x >= b).where(b+1-n, b-n) if equidistant_case == "round_to_even": - def _and(cond1, cond2): return ((cond1 + cond2) == 2).where(1, 0) + def _and(cond1, cond2): return ((cond1.cast(dtypes.int) + cond2.cast(dtypes.int)) == 2).where(1, 0) x_ceil_fraction = x.ceil()/2 cond_ceil_even = x_ceil_fraction.ceil() == x_ceil_fraction x = (_and(x == b, cond_ceil_even)).where(x+1-n, x) diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 77a7502d50..cd0c25fb32 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -483,7 +483,7 @@ class Linearizer(Kernel): key = (uop, dtype, vin, arg) if uop == UOps.ALU: - if arg == BinaryOps.CMPLT: assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}" + if arg in (BinaryOps.CMPLT, BinaryOps.CMPEQ): assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}" if arg == TernaryOps.WHERE: assert vin[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch 
{vin[0].dtype=} != {dtypes.bool}" if simplify: @@ -539,6 +539,6 @@ class Linearizer(Kernel): if input_acc[off] != acc[off]: acc[off] = self.uop(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]) + tuple(loop_ctx)) else: - ret = [self.uop(UOps.ALU, dtype=dtypes.bool if x.op == BinaryOps.CMPLT else val[-1].dtype, vin=val, arg=x.op) for val in zip(*values)] + ret = [self.uop(UOps.ALU, dtypes.bool if x.op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else val[-1].dtype, val, x.op) for val in zip(*values)] cache[x] = ret return ret diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py index 4e97354136..c0eefd0ed4 100644 --- a/tinygrad/features/search.py +++ b/tinygrad/features/search.py @@ -13,7 +13,8 @@ actions = flatten([[Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3, actions += flatten([[Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4]] for axis in range(4)]) actions += flatten([[Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29]] for axis in range(5)]) actions += flatten([[Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256]] for axis in range(3)]) -actions += flatten([[Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32]] for axis in range(7)]) +# TODO: fix PADTO. 
in addition to STORE, it also needs to gate acc update +# actions += flatten([[Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32]] for axis in range(7)]) actions += [ Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8), diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index e9bcaba883..4d1ba527be 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -110,7 +110,7 @@ class LazyBuffer: srcs.append(s) assert all_same(dts:=[x.dtype.scalar() for x in (srcs if op != TernaryOps.WHERE else srcs[1:])]), f"all dtypes must match {dts} on {op}" if op == TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool" - out_dtype = srcs[-1].dtype if op != BinaryOps.CMPLT else dtypes.bool + out_dtype = srcs[-1].dtype if op not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else dtypes.bool return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs)) # *** reduce ops *** @@ -229,7 +229,7 @@ def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffe realizes.add(buf.srcs[0].base) for x in buf.srcs: _recurse_lb(x, realizes, allbufs, simple_pads) -UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP} +UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP} def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool: if buf in realizes or buf.realized: return True # NOTE: this broke to_image_idx and coder with JIT diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index 8f4cc62800..cdbdc782a2 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -90,6 +90,10 @@ class Less(Function): def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPLT, y) +class Eq(Function): + def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: + return 
x.e(BinaryOps.CMPEQ, y) + class Xor(Function): def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.XOR, y) @@ -158,7 +162,7 @@ class Max(Function): def backward(self, grad_output:LazyBuffer) -> LazyBuffer: # 1s in locations where the max was chosen (can be two locations) - max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape)).cast(self.x.dtype)) + max_is_1s = self.x.e(BinaryOps.CMPEQ, self.ret.expand(self.x.shape)).cast(self.x.dtype) div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape) return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape)) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index d3c1f5b70d..112573e553 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -11,7 +11,8 @@ from dataclasses import dataclass # NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars # NOTE: rdna3 only has RECIP and not DIV. DIV is on the chopping block class UnaryOps(Enum): EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702 -class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); XOR = auto() # noqa: E702 +class BinaryOps(Enum): + ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPEQ = auto(); XOR = auto() # noqa: E702 class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702 class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702 @@ -87,7 +88,7 @@ InterpretedFlopCounter: Dict[Op, Callable] = { BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, arg.dtype, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize*arg.st.size()}), # noqa: E501 UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem), 
# cast uses no flops **{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op != UnaryOps.CAST}, # noqa: E501 - **{op:lambda self,y,op=op: FlopCounter(self.shape, dtypes.bool if op == BinaryOps.CMPLT else self.dtype, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501 + **{op:lambda self,y,op=op: FlopCounter(self.shape, dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else self.dtype, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501 **{op:lambda self,new_shape: FlopCounter(new_shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, y.dtype, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501 diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 7a44dc6733..dbafe88e7e 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -35,7 +35,7 @@ class CStyleLanguage(NamedTuple): BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})", - BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})", + BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPEQ: lambda a,b,dtype: f"({a}=={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})", TernaryOps.MULACC: lambda a,b,c,dtype: f"(({a}*{b})+{c})", TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})" } @@ -74,11 +74,9 @@ class CStyleLanguage(NamedTuple): def render_for(self, expr: str, _min:Union[int,str], _max:Union[int,str]) -> str: 
return f"for ({self.generic_var_prefix if self.generic_var_prefix else 'int'} {expr} = {_min}; {expr} < {_max}; {expr}++) {{" - def render_if(self, cond: str): - return f"if ({cond}) {{" + def render_if(self, cond: str): return f"if ({cond}) {{" - def render_conditional(self, cond: str, x:str, y:str) -> str: - return f"({cond})?({x}):{y}" + def render_conditional(self, cond: str, x:str, y:str) -> str: return f"({cond})?({x}):{y}" def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str: tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else "" # noqa: E501 @@ -306,13 +304,13 @@ class WGSLLanguage(CStyleLanguage): barrier="workgroupBarrier();" generic_var_prefix = "var " external_local_bufs = True - code_for_op = { **CStyleLanguage().code_for_op, BinaryOps.CMPLT: lambda x,y,dtype: f"f32({x}<{y})", + code_for_op = { **CStyleLanguage().code_for_op, + BinaryOps.CMPLT: lambda x,y,dtype: f"f32({x}<{y})", BinaryOps.CMPEQ: lambda x,y,dtype: f"f32({x}=={y})", TernaryOps.MULACC: lambda x,y,z,dtype: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c,dtype: f"select({c},{b},bool({a}))" } - # HACK: write bool as f32. 
remove after elementwise op cast inputs properly + # HACK: write bool as f32 type_map = {dtypes.float: "f32", dtypes.half: "f16", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "f32"} - def render_local(self, name: str, dtype:DType, size: int): - return f"var<workgroup> {name}: array<{self.type_map[dtype]},{size}>;" + def render_local(self, name: str, dtype:DType, size: int): return f"var<workgroup> {name}: array<{self.type_map[dtype]},{size}>;" def render_const(self, x:Union[float,int], var_dtype) -> str: if math.isnan(x): return "nan()" @@ -327,11 +325,9 @@ class WGSLLanguage(CStyleLanguage): prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}" # noqa: E501 return prg - def render_if(self, cond: str): - return f"if (bool({cond})) {{" + def render_if(self, cond: str): return f"if (bool({cond})) {{" - def render_conditional(self, cond:str, x:str, y:str) -> str: - return f"select({y}, {x}, bool({cond}))" + def render_conditional(self, cond:str, x:str, y:str) -> str: return f"select({y}, {x}, bool({cond}))" def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str: if self.type_map[var_dtype]: return f"bitcast<{self.type_map[var_dtype]}>({x[0]})" if bitcast else f"{self.type_map[var_dtype]}({x[0]})" diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index c39e7978fe..88c2499ad2 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -26,6 +26,7 @@ code_for_op: Final[Dict[Op, Callable]] = { BinaryOps.DIV: lambda builder, x, y, var_dtype: builder.udiv(x, y) if is_bool_or_unsigned(var_dtype) else builder.sdiv(x, y) if dtypes.is_int(var_dtype) else builder.fdiv(x, y, flags=MFLAGS), BinaryOps.CMPLT: lambda builder, x, y, var_dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(var_dtype) else 
builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501 + BinaryOps.CMPEQ: lambda builder, x, y, var_dtype: builder.icmp_unsigned("==", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed("==", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered("==", x, y, flags=MFLAGS), # noqa: E501 BinaryOps.MAX: lambda builder, x, y, var_dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501 BinaryOps.MOD: lambda builder, x, y, var_dtype: builder.urem(x, y) if is_bool_or_unsigned(var_dtype) else builder.srem(x, y) if dtypes.is_int(var_dtype) else builder.frem(x, y), @@ -162,7 +163,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]: else: store_op() if uop == UOps.ALU: if args == TernaryOps.MULACC: lvars[u] = mulacc(bb, lvars[vin[0]], lvars[vin[1]], lvars[vin[2]], vin[0].dtype, dtype) - else: lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin] + [dtype if args != BinaryOps.CMPLT else vin[0].dtype]) + else: lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin] + [vin[0].dtype if args in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else dtype]) if uop == UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=isinstance(args, tuple) and args[1]) bb[-1].ret_void() diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py index fddd37f0c9..730a173e5b 100644 --- a/tinygrad/runtime/ops_cpu.py +++ b/tinygrad/runtime/ops_cpu.py @@ -28,7 +28,7 @@ numpy_fxn_for_op: Dict[Op, Callable] = { UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin, UnaryOps.SQRT: np.sqrt, UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False), UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x), - BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: np.less, + BinaryOps.MAX: 
np.maximum, BinaryOps.CMPLT: np.less, BinaryOps.CMPEQ: np.equal, BinaryOps.ADD: np.add, BinaryOps.SUB: np.subtract, BinaryOps.MUL: np.multiply, BinaryOps.DIV: lambda x, y: np.divide(x, y).astype(x.dtype, copy=False), BinaryOps.XOR: np.bitwise_xor, TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy(), optimize=True), lambda x: x.strides, np.broadcast_to), diff --git a/tinygrad/runtime/ops_torch.py b/tinygrad/runtime/ops_torch.py index 3d02641c61..b6d9612378 100644 --- a/tinygrad/runtime/ops_torch.py +++ b/tinygrad/runtime/ops_torch.py @@ -31,7 +31,7 @@ torch_fxn_for_op: Dict[Op, Callable] = { UnaryOps.EXP2: torch.exp2, UnaryOps.LOG2: torch.log2, UnaryOps.SIN: torch.sin, UnaryOps.SQRT: torch.sqrt, UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(inverse_type_map[y[0]]), UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x), - BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: torch.lt, + BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: torch.lt, BinaryOps.CMPEQ: torch.eq, BinaryOps.ADD: torch.add, BinaryOps.SUB: torch.sub, BinaryOps.MUL: torch.mul, BinaryOps.DIV: lambda x,y: torch.div(x, y).type(x.dtype), BinaryOps.XOR: torch.bitwise_xor, TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(a.dtype), lambda x: x.stride(), lambda x,s: x.expand(s)), diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index fa906fc945..fe5a25feb1 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -823,12 +823,12 @@ class Tensor: def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x)) # in webgpu bool cannot be used as a storage buffer type - def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False)).cast(dtypes.float if self.device == "WEBGPU" else dtypes.bool) - def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True)).cast(dtypes.float if self.device == "WEBGPU" else dtypes.bool) - def __ge__(self, x) -> Tensor: return 
1.0-(self<x) - def __le__(self, x) -> Tensor: return 1.0-(self>x) - def __ne__(self, x) -> Tensor: return (self<x) + (self>x) # type: ignore[override] - def __eq__(self, x) -> Tensor: return 1.0-(self != x) # type: ignore[override] + def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False)).cast(dtypes.float if Device.DEFAULT=="WEBGPU" else dtypes.bool) + def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True)).cast(dtypes.float if Device.DEFAULT=="WEBGPU" else dtypes.bool) + def __ge__(self, x) -> Tensor: return (self<x).neg() if Device.DEFAULT!="WEBGPU" else (1-(self<x)).cast(dtypes.float) + def __le__(self, x) -> Tensor: return (self>x).neg() if Device.DEFAULT!="WEBGPU" else (1-(self>x)).cast(dtypes.float) + def __eq__(self, x) -> Tensor: return mlops.Eq.apply(*self._broadcasted(x, True)).cast(dtypes.float if Device.DEFAULT=="WEBGPU" else dtypes.bool) # type: ignore[override] + def __ne__(self, x) -> Tensor: return (self==x).neg() if Device.DEFAULT!="WEBGPU" else (1-(self==x)).cast(dtypes.float) # type: ignore[override] # ***** functional nn ops *****