mirror of https://github.com/tinygrad/tinygrad.git
add CMPEQ (#2931)
* CMPEQ
* work
* fix onnx
* fix round
* fix webgpu
* prettier
* no PADTO in actions
@@ -364,7 +364,7 @@ def NegativeLogLikelihoodLoss(x: Tensor, target: Tensor, weight=None, ignore_ind
 cond = target == ignore_index
 weight = cond.where(0, weight) if weight is not None else cond.where(Tensor.zeros(*target.shape), 1)
 mask = target[:, None, :] == Tensor.arange(C).reshape([1, C] + [1]*(len(x.shape) -2))
-loss = (-mask * x).sum(axis=1) * (1 if weight is None else weight)
+loss = -(mask * x).sum(axis=1) * (1 if weight is None else weight)
 if reduction == "mean": return loss.mean() if weight is None else loss.sum() / weight.sum()
 if reduction == "sum": return loss.sum()
 return loss.reshape(t_shape) if len(i_shape) != 3 else loss
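Why the negation moves outside the product: once `==` produces a real boolean mask, negating the mask itself is not well defined, while negating the float product always is. A minimal numpy sketch (not the tinygrad code) of the same pitfall:

import numpy as np

x = np.arange(12, dtype=np.float32).reshape(3, 4)   # fake logits
mask = (np.arange(4) == 1)[None, :]                  # boolean one-hot class mask

# (-mask * x) would have to negate a boolean array first; numpy refuses:
#   TypeError: The numpy boolean negative, the `-` operator, is not supported
# Multiplying first promotes bool -> float, so negating the product is safe.
loss = -(mask * x).sum(axis=1)
print(loss)   # [-1. -5. -9.]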
@@ -404,7 +404,7 @@ def _round(x:Tensor, n:float, equidistant_case = "round_down") -> Tensor:
 if equidistant_case == "round_down": return (x > b).where(b+1-n, b-n)
 if equidistant_case == "round_up": return (x >= b).where(b+1-n, b-n)
 if equidistant_case == "round_to_even":
-def _and(cond1, cond2): return ((cond1 + cond2) == 2).where(1, 0)
+def _and(cond1, cond2): return ((cond1.cast(dtypes.int) + cond2.cast(dtypes.int)) == 2).where(1, 0)
 x_ceil_fraction = x.ceil()/2
 cond_ceil_even = x_ceil_fraction.ceil() == x_ceil_fraction
 x = (_and(x == b, cond_ceil_even)).where(x+1-n, x)
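The `_and` fix guards against boolean addition: when `cond1` and `cond2` are bool tensors (as `==` now produces), their sum can stay boolean and never reach 2, so casting to int first keeps the "both true" test meaningful. A small numpy illustration of that assumption:

import numpy as np

a = np.array([True, True, False])
b = np.array([True, False, False])

print(a + b)                    # [ True  True False] -- bool + bool acts as logical OR, never 2
print((a + b) == 2)             # [False False False] -- the un-cast test silently fails
print((a.astype(np.int32) + b.astype(np.int32)) == 2)   # [ True False False]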
@@ -483,7 +483,7 @@ class Linearizer(Kernel):
 key = (uop, dtype, vin, arg)
 
 if uop == UOps.ALU:
-if arg == BinaryOps.CMPLT: assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
+if arg in (BinaryOps.CMPLT, BinaryOps.CMPEQ): assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
 if arg == TernaryOps.WHERE: assert vin[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {vin[0].dtype=} != {dtypes.bool}"
 
 if simplify:
@@ -539,6 +539,6 @@ class Linearizer(Kernel):
 if input_acc[off] != acc[off]:
 acc[off] = self.uop(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]) + tuple(loop_ctx))
 else:
-ret = [self.uop(UOps.ALU, dtype=dtypes.bool if x.op == BinaryOps.CMPLT else val[-1].dtype, vin=val, arg=x.op) for val in zip(*values)]
+ret = [self.uop(UOps.ALU, dtypes.bool if x.op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else val[-1].dtype, val, x.op) for val in zip(*values)]
 cache[x] = ret
 return ret
@@ -13,7 +13,8 @@ actions = flatten([[Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,
 actions += flatten([[Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4]] for axis in range(4)])
 actions += flatten([[Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29]] for axis in range(5)])
 actions += flatten([[Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256]] for axis in range(3)])
-actions += flatten([[Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32]] for axis in range(7)])
+# TODO: fix PADTO. in addition to STORE, it also needs to gate acc update
+# actions += flatten([[Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32]] for axis in range(7)])
 actions += [
 Opt(op=OptOps.LOCAL, axis=0, amt=32),
 Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8),
@@ -110,7 +110,7 @@ class LazyBuffer:
 srcs.append(s)
 assert all_same(dts:=[x.dtype.scalar() for x in (srcs if op != TernaryOps.WHERE else srcs[1:])]), f"all dtypes must match {dts} on {op}"
 if op == TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
-out_dtype = srcs[-1].dtype if op != BinaryOps.CMPLT else dtypes.bool
+out_dtype = srcs[-1].dtype if op not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else dtypes.bool
 return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
 
 # *** reduce ops ***
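The rule being extended here (and asserted in the Linearizer hunk above) is simple: comparison ops produce bool, everything else keeps the dtype of the last source. A hypothetical standalone helper mirroring it (names are illustrative, not tinygrad API):

from enum import Enum, auto

class BinOp(Enum): ADD = auto(); MUL = auto(); CMPLT = auto(); CMPEQ = auto()

def out_dtype(op: BinOp, src_dtype: str) -> str:
  # comparisons always yield bool; arithmetic keeps the source dtype
  return "bool" if op in (BinOp.CMPLT, BinOp.CMPEQ) else src_dtype

assert out_dtype(BinOp.CMPEQ, "float32") == "bool"
assert out_dtype(BinOp.ADD, "float32") == "float32"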
@@ -229,7 +229,7 @@ def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffe
 realizes.add(buf.srcs[0].base)
 for x in buf.srcs: _recurse_lb(x, realizes, allbufs, simple_pads)
 
-UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
+UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
 def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool:
 if buf in realizes or buf.realized: return True
 # NOTE: this broke to_image_idx and coder with JIT
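Why CMPEQ joins the pad-unsafe set: zeros introduced by padding can compare equal in the padded region and produce spurious True elements, which corrupts any reduction over the padded axis. A hedged numpy sketch of the hazard padding has to gate against:

import numpy as np

x = np.array([3., 0., 5.])
eq_zero = (x == 0.).sum()                 # 1 true element in the real data

x_padded = np.pad(x, (0, 5))              # pad to length 8 with zeros
eq_zero_padded = (x_padded == 0.).sum()   # 6 -- the padded zeros all compare equal

print(eq_zero, eq_zero_padded)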
@@ -90,6 +90,10 @@ class Less(Function):
 def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
 return x.e(BinaryOps.CMPLT, y)
 
+class Eq(Function):
+def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
+return x.e(BinaryOps.CMPEQ, y)
+
 class Xor(Function):
 def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
 return x.e(BinaryOps.XOR, y)
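With the new `Eq` function wired to `BinaryOps.CMPEQ`, elementwise equality no longer has to be synthesized from two `<` comparisons. A usage sketch (assuming a tinygrad checkout containing this commit; per the tensor.py hunk below the result dtype is bool, or float on WEBGPU):

from tinygrad.tensor import Tensor

a = Tensor([1.0, 2.0, 3.0])
b = Tensor([1.0, 0.0, 3.0])
print((a == b).numpy())   # expected: [ True False  True] (elementwise CMPEQ)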
@@ -158,7 +162,7 @@ class Max(Function):
 
 def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
 # 1s in locations where the max was chosen (can be two locations)
-max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape)).cast(self.x.dtype))
+max_is_1s = self.x.e(BinaryOps.CMPEQ, self.ret.expand(self.x.shape)).cast(self.x.dtype)
 div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
 return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
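The Max backward rewrite replaces `1.0 - (x < max)` with a direct `x == max`; for well-ordered floats the two masks agree, and the CMPEQ form skips the SUB and reads more directly. A numpy check of the equivalence (illustration only, not the tinygrad code):

import numpy as np

x = np.array([[1., 7., 7.], [4., 2., 0.]])
m = x.max(axis=1, keepdims=True)

old_mask = 1.0 - (x < m).astype(x.dtype)   # old: 1 - CMPLT
new_mask = (x == m).astype(x.dtype)        # new: CMPEQ

print(np.array_equal(old_mask, new_mask))  # True
print(new_mask)                            # [[0. 1. 1.] [1. 0. 0.]] -- ties both get credit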
@@ -11,7 +11,8 @@ from dataclasses import dataclass
 # NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
 # NOTE: rdna3 only has RECIP and not DIV. DIV is on the chopping block
 class UnaryOps(Enum): EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
-class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); XOR = auto() # noqa: E702
+class BinaryOps(Enum):
+  ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPEQ = auto(); XOR = auto() # noqa: E702
 class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
@@ -87,7 +88,7 @@ InterpretedFlopCounter: Dict[Op, Callable] = {
 BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, arg.dtype, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize*arg.st.size()}), # noqa: E501
 UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem), # cast uses no flops
 **{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op != UnaryOps.CAST}, # noqa: E501
-**{op:lambda self,y,op=op: FlopCounter(self.shape, dtypes.bool if op == BinaryOps.CMPLT else self.dtype, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
+**{op:lambda self,y,op=op: FlopCounter(self.shape, dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else self.dtype, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
 **{op:lambda self,new_shape: FlopCounter(new_shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps},
 TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, y.dtype, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501
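One detail worth noting in that dict comprehension: the `op=op` default argument is what freezes the loop variable into each lambda; without it every BinaryOps entry would see the last `op` of the loop. A tiny self-contained reminder of the Python closure behavior being relied on:

ops = ["ADD", "CMPLT", "CMPEQ"]

late = {op: (lambda: op) for op in ops}          # late binding: every lambda sees the last op
early = {op: (lambda op=op: op) for op in ops}   # default arg captures each op at definition time

print(late["ADD"]())    # 'CMPEQ'
print(early["ADD"]())   # 'ADD'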
@@ -35,7 +35,7 @@ class CStyleLanguage(NamedTuple):
 BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})",
 BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})",
 BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
-BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
+BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPEQ: lambda a,b,dtype: f"({a}=={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
 TernaryOps.MULACC: lambda a,b,c,dtype: f"(({a}*{b})+{c})", TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"
 }
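The C-style entry just renders the comparison inline; calling local copies of the two lambdas from the table above (duplicated here so the snippet runs without tinygrad installed) shows the emitted expressions:

cmplt = lambda a, b, dtype: f"({a}<{b})"
cmpeq = lambda a, b, dtype: f"({a}=={b})"

print(cmplt("val0", "val1", None))   # (val0<val1)
print(cmpeq("val0", "val1", None))   # (val0==val1)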
@@ -74,11 +74,9 @@ class CStyleLanguage(NamedTuple):
 def render_for(self, expr: str, _min:Union[int,str], _max:Union[int,str]) -> str:
 return f"for ({self.generic_var_prefix if self.generic_var_prefix else 'int'} {expr} = {_min}; {expr} < {_max}; {expr}++) {{"
 
-def render_if(self, cond: str):
-return f"if ({cond}) {{"
+def render_if(self, cond: str): return f"if ({cond}) {{"
 
-def render_conditional(self, cond: str, x:str, y:str) -> str:
-return f"({cond})?({x}):{y}"
+def render_conditional(self, cond: str, x:str, y:str) -> str: return f"({cond})?({x}):{y}"
 
 def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str:
 tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else "" # noqa: E501
@@ -306,13 +304,13 @@ class WGSLLanguage(CStyleLanguage):
 barrier="workgroupBarrier();"
 generic_var_prefix = "var "
 external_local_bufs = True
-code_for_op = { **CStyleLanguage().code_for_op, BinaryOps.CMPLT: lambda x,y,dtype: f"f32({x}<{y})",
+code_for_op = { **CStyleLanguage().code_for_op,
+BinaryOps.CMPLT: lambda x,y,dtype: f"f32({x}<{y})", BinaryOps.CMPEQ: lambda x,y,dtype: f"f32({x}=={y})",
 TernaryOps.MULACC: lambda x,y,z,dtype: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c,dtype: f"select({c},{b},bool({a}))" }
-# HACK: write bool as f32. remove after elementwise op cast inputs properly
+# HACK: write bool as f32
 type_map = {dtypes.float: "f32", dtypes.half: "f16", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "f32"}
 
-def render_local(self, name: str, dtype:DType, size: int):
-return f"var<workgroup> {name}: array<{self.type_map[dtype]},{size}>;"
+def render_local(self, name: str, dtype:DType, size: int): return f"var<workgroup> {name}: array<{self.type_map[dtype]},{size}>;"
 
 def render_const(self, x:Union[float,int], var_dtype) -> str:
 if math.isnan(x): return "nan()"
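On WGSL the comparisons are wrapped in f32(...) because bool cannot be used as a storage buffer type there (hence the type_map hack mapping dtypes.bool to f32). Evaluating local copies of two entries from the table above shows the strings the renderer would emit; note WGSL's select takes (false_value, true_value, condition):

# Local copies of the WGSL table entries above; illustrative only.
wgsl_cmpeq = lambda x, y, dtype: f"f32({x}=={y})"
wgsl_where = lambda a, b, c, dtype: f"select({c},{b},bool({a}))"   # select(false_val, true_val, cond)

print(wgsl_cmpeq("val0", "val1", None))    # f32(val0==val1)
print(wgsl_where("cond", "t", "f", None))  # select(f,t,bool(cond))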
@@ -327,11 +325,9 @@ class WGSLLanguage(CStyleLanguage):
 prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}" # noqa: E501
 return prg
 
-def render_if(self, cond: str):
-return f"if (bool({cond})) {{"
+def render_if(self, cond: str): return f"if (bool({cond})) {{"
 
-def render_conditional(self, cond:str, x:str, y:str) -> str:
-return f"select({y}, {x}, bool({cond}))"
+def render_conditional(self, cond:str, x:str, y:str) -> str: return f"select({y}, {x}, bool({cond}))"
 
 def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
 if self.type_map[var_dtype]: return f"bitcast<{self.type_map[var_dtype]}>({x[0]})" if bitcast else f"{self.type_map[var_dtype]}({x[0]})"
@@ -26,6 +26,7 @@ code_for_op: Final[Dict[Op, Callable]] = {
 BinaryOps.DIV: lambda builder, x, y, var_dtype:
 builder.udiv(x, y) if is_bool_or_unsigned(var_dtype) else builder.sdiv(x, y) if dtypes.is_int(var_dtype) else builder.fdiv(x, y, flags=MFLAGS),
 BinaryOps.CMPLT: lambda builder, x, y, var_dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501
+BinaryOps.CMPEQ: lambda builder, x, y, var_dtype: builder.icmp_unsigned("==", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed("==", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered("==", x, y, flags=MFLAGS), # noqa: E501
 BinaryOps.MAX: lambda builder, x, y, var_dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
 BinaryOps.MOD: lambda builder, x, y, var_dtype:
 builder.urem(x, y) if is_bool_or_unsigned(var_dtype) else builder.srem(x, y) if dtypes.is_int(var_dtype) else builder.frem(x, y),
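For the LLVM backend the float path uses an unordered equality compare. A minimal llvmlite sketch (separate from tinygrad, assuming llvmlite is installed) that emits the same kind of instruction the new CMPEQ entry builds:

from llvmlite import ir

mod = ir.Module(name="cmpeq_demo")
fnty = ir.FunctionType(ir.IntType(1), [ir.FloatType(), ir.FloatType()])
fn = ir.Function(mod, fnty, name="feq")
bb = ir.IRBuilder(fn.append_basic_block("entry"))
x, y = fn.args
bb.ret(bb.fcmp_unordered("==", x, y))   # lowers to an 'fcmp ueq' returning i1
print(mod)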
@@ -162,7 +163,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
 else: store_op()
 if uop == UOps.ALU:
 if args == TernaryOps.MULACC: lvars[u] = mulacc(bb, lvars[vin[0]], lvars[vin[1]], lvars[vin[2]], vin[0].dtype, dtype)
-else: lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin] + [dtype if args != BinaryOps.CMPLT else vin[0].dtype])
+else: lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin] + [vin[0].dtype if args in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else dtype])
 if uop == UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=isinstance(args, tuple) and args[1])
 
 bb[-1].ret_void()
@@ -28,7 +28,7 @@ numpy_fxn_for_op: Dict[Op, Callable] = {
 UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin, UnaryOps.SQRT: np.sqrt,
 UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
 UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
-BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: np.less,
+BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: np.less, BinaryOps.CMPEQ: np.equal,
 BinaryOps.ADD: np.add, BinaryOps.SUB: np.subtract, BinaryOps.MUL: np.multiply,
 BinaryOps.DIV: lambda x, y: np.divide(x, y).astype(x.dtype, copy=False), BinaryOps.XOR: np.bitwise_xor,
 TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy(), optimize=True), lambda x: x.strides, np.broadcast_to),
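On the interpreted numpy backend CMPEQ maps straight to np.equal, which already returns a bool array, consistent with the bool output dtype enforced in the LazyBuffer change above:

import numpy as np
out = np.equal(np.array([1., 2., 3.]), np.array([1., 0., 3.]))
print(out, out.dtype)   # [ True False  True] bool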
@@ -31,7 +31,7 @@ torch_fxn_for_op: Dict[Op, Callable] = {
 UnaryOps.EXP2: torch.exp2, UnaryOps.LOG2: torch.log2, UnaryOps.SIN: torch.sin, UnaryOps.SQRT: torch.sqrt,
 UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(inverse_type_map[y[0]]),
 UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x),
-BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: torch.lt,
+BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: torch.lt, BinaryOps.CMPEQ: torch.eq,
 BinaryOps.ADD: torch.add, BinaryOps.SUB: torch.sub, BinaryOps.MUL: torch.mul,
 BinaryOps.DIV: lambda x,y: torch.div(x, y).type(x.dtype), BinaryOps.XOR: torch.bitwise_xor,
 TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(a.dtype), lambda x: x.stride(), lambda x,s: x.expand(s)),
@@ -823,12 +823,12 @@ class Tensor:
 def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x))
 
 # in webgpu bool cannot be used as a storage buffer type
-def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False)).cast(dtypes.float if self.device == "WEBGPU" else dtypes.bool)
-def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True)).cast(dtypes.float if self.device == "WEBGPU" else dtypes.bool)
-def __ge__(self, x) -> Tensor: return 1.0-(self<x)
-def __le__(self, x) -> Tensor: return 1.0-(self>x)
-def __ne__(self, x) -> Tensor: return (self<x) + (self>x) # type: ignore[override]
-def __eq__(self, x) -> Tensor: return 1.0-(self != x) # type: ignore[override]
+def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False)).cast(dtypes.float if Device.DEFAULT=="WEBGPU" else dtypes.bool)
+def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True)).cast(dtypes.float if Device.DEFAULT=="WEBGPU" else dtypes.bool)
+def __ge__(self, x) -> Tensor: return (self<x).neg() if Device.DEFAULT!="WEBGPU" else (1-(self<x)).cast(dtypes.float)
+def __le__(self, x) -> Tensor: return (self>x).neg() if Device.DEFAULT!="WEBGPU" else (1-(self>x)).cast(dtypes.float)
+def __eq__(self, x) -> Tensor: return mlops.Eq.apply(*self._broadcasted(x, True)).cast(dtypes.float if Device.DEFAULT=="WEBGPU" else dtypes.bool) # type: ignore[override]
+def __ne__(self, x) -> Tensor: return (self==x).neg() if Device.DEFAULT!="WEBGPU" else (1-(self==x)).cast(dtypes.float) # type: ignore[override]
 
 # ***** functional nn ops *****
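The derived comparisons are now logical negations of `<`, `>`, and the new `==` (via .neg(), which this same commit maps to logical_not for bool tensors in the interpreted backends), instead of float arithmetic like 1.0 - (self < x). The numpy equivalent of the new derivations, for reference:

import numpy as np

a = np.array([1., 2., 3.])
b = np.array([2., 2., 2.])

ge = np.logical_not(a < b)    # __ge__  -> not (self < x)
le = np.logical_not(a > b)    # __le__  -> not (self > x)
ne = np.logical_not(a == b)   # __ne__  -> not (self == x)

print(ge)   # [False  True  True]
print(le)   # [ True  True False]
print(ne)   # [ True False  True]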