fix neg logical_not inconsistencies (#3222)

* try

* test: add logical_not tests

* gah, my mistake, but this doesn't match types for const()

* fix: can't we just do this?

* big change: I don't actually know what I'm doing

* WOOO I'M JUST CHANGING EVERYTHING WOW, probably gonna revert later

* BYE BYE noqa: E501

* fix: fewer lines and add test

* fix: rm 2 redundant tests

* fix: eq with False so we don't unintentionally implicitly upcast, but it's bool anyway so w/e
geohotstan, 2024-01-25 00:48:40 +08:00, committed by GitHub
parent e2e4632aea
commit 842053873d
9 changed files with 14 additions and 12 deletions
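
Taken together, the diff below makes UnaryOps.NEG purely arithmetic and gives bool tensors an explicit logical_not path instead. A minimal sketch of the intended user-facing behavior, assuming a tinygrad checkout at this commit and using numpy only as a reference (the expected outputs are my reading of the diff, not taken verbatim from it):

import numpy as np
from tinygrad.tensor import Tensor

# logical_not now lowers to an elementwise compare-with-False, so the result stays bool
t = Tensor([True, False, True])
print(t.logical_not().numpy())                         # expected: [False  True False]
print(np.logical_not(np.array([True, False, True])))   # numpy reference

# non-bool tensors keep plain arithmetic negation through UnaryOps.NEG
x = Tensor([1.0, -2.0, 0.5])
print((-x).numpy())                                    # expected: [-1.   2.  -0.5]

# bool tensors: __neg__ now routes to logical_not instead of emitting UnaryOps.NEG
print((-t).numpy())                                    # expected: same as t.logical_not()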


@@ -156,7 +156,7 @@ class PTXLanguage(AssemblyLanguage):
gdim = [f'%nctaid.{chr(120+i)}' for i in range(3)]
lid = [f'%tid.{chr(120+i)}' for i in range(3)]
asm_for_op = {
UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if dt == dtypes.bool else f"neg.{name} {d}, {a};",
UnaryOps.NEG: lambda d,a,dt,name: f"neg.{name} {d}, {a};",
UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};",
UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",


@@ -63,7 +63,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
UnaryOps.LOG2: lambda x,dtype,: f"tl.math.log2({x})",
UnaryOps.SIN: lambda x,dtype: f"tl.sin({x})",
UnaryOps.SQRT: lambda x,dtype: f"tl.sqrt({x})",
UnaryOps.NEG: lambda x,dtype: f"-{x}" if dtype != dtypes.bool else f"tl.where({x}, 0, 1)",
UnaryOps.NEG: lambda x,dtype: f"-{x}",
BinaryOps.ADD: lambda x,y,dtype: f"({x}+{y})", BinaryOps.SUB: lambda x,y,: f"({x}-{y})",
BinaryOps.MUL: lambda x,y,dtype: f"({x}*{y})", BinaryOps.DIV: lambda x,y,: f"({x}/{y})" if y != '0.0' else f"{x}*tl.where({x}==0.0, float('nan'), float('inf'))",
BinaryOps.MAX: lambda x,y,dtype: f"tl.maximum({x},{y})",


@@ -290,8 +290,11 @@ class TestOps(unittest.TestCase):
helper_test_op([()], lambda x: 2-x)
def test_neg(self):
helper_test_op([(45,65)], lambda x: -x)
helper_test_op([()], lambda x: -x)
helper_test_op([(45,65)], lambda x: x.neg())
helper_test_op([()], lambda x: x.neg())
+ def test_logical_not(self):
+ helper_test_op(None, torch.logical_not, Tensor.logical_not, vals=[[True, False, True]], forward_only=True)
+ helper_test_op(None, torch.logical_not, Tensor.logical_not, vals=[[1.,2.,0.,0.5]], forward_only=True)
def test_mul(self):
helper_test_op([(64,64), (64,64)], lambda x,y: x*y, Tensor.mul)


@@ -109,6 +109,7 @@ class LazyBuffer:
assert all_same(dts:=[x.dtype.scalar() for x in (srcs[1:] if op is TernaryOps.WHERE else srcs)]), f"all dtypes must match {dts} on {op}"
assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}"
if op is TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
+ if op is UnaryOps.NEG: assert srcs[0].dtype != dtypes.bool, "UnaryOps.NEG does not accept dtype bool"
out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else srcs[-1].dtype
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
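
The new assert turns any leftover NEG-on-bool path into a loud failure rather than a silent logical not. As an illustration (the exact failure point and message timing are assumptions based on this assert, not shown in the diff), calling neg() directly on a bool tensor should now be rejected, while the - operator is redirected to logical_not by the __neg__ change later in this commit:

from tinygrad.tensor import Tensor

t = Tensor([True, False])
print((-t).numpy())        # fine: __neg__ dispatches to logical_not for bool
try:
  t.neg()                  # would build UnaryOps.NEG on a bool LazyBuffer
except AssertionError as e:
  print("rejected:", e)    # "UnaryOps.NEG does not accept dtype bool"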


@@ -27,7 +27,7 @@ class CStyleLanguage(NamedTuple):
launch_bounds: bool = False
type_map: Dict[DType, str] = {}
code_for_op: Dict = {
UnaryOps.NEG: lambda x,dtype: f"(-{x})" if dtype != dtypes.bool else f"(!{x})", UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
UnaryOps.NEG: lambda x,dtype: f"(-{x})", UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", UnaryOps.SIN: lambda x,dtype: f"sin({x})",
BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})",
BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",


@@ -9,7 +9,7 @@ MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but
def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype)
code_for_op: Final[Dict[Op, Callable]] = {
- UnaryOps.NEG: lambda builder, x, var_dtype: builder.xor(x, ir.Constant(ir.IntType(1), 1)) if var_dtype == dtypes.bool else builder.neg(x) if dtypes.is_int(var_dtype) else builder.fneg(x, flags=MFLAGS), # noqa: E501
+ UnaryOps.NEG: lambda builder, x, var_dtype: builder.neg(x) if dtypes.is_int(var_dtype) else builder.fneg(x, flags=MFLAGS),
UnaryOps.EXP2: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=MFLAGS),
UnaryOps.LOG2: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [x.type]), [x], fastmath=MFLAGS),
UnaryOps.SIN: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [x.type]), [x], fastmath=MFLAGS),


@@ -25,9 +25,8 @@ def as_strided(x, arg):
numpy_fxn_for_op: Dict[Op, Callable] = {
BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np),
- UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin, UnaryOps.SQRT: np.sqrt,
+ UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin, UnaryOps.SQRT: np.sqrt, UnaryOps.NEG: np.negative,
UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
- UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: np.less, BinaryOps.CMPEQ: np.equal, BinaryOps.ADD: np.add, BinaryOps.SUB: np.subtract,
BinaryOps.MUL: np.multiply, BinaryOps.DIV: lambda x, y: np.divide(x, y).astype(x.dtype, copy=False), BinaryOps.XOR: np.bitwise_xor,
ReduceOps.SUM: lambda x, new_shape: x.sum(reduce_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True) if x.shape != new_shape else x,
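
For context (a standalone numpy sketch, not code from this repo): numpy itself refuses to negate bool arrays, which is the same restriction the new LazyBuffer assert enforces, and comparing against False reproduces logical_not for the values the new test exercises:

import numpy as np

b = np.array([True, False, True])
f = np.array([1., 2., 0., 0.5])

# np.negative on a bool array raises TypeError, so a bool NEG could not be lowered here anyway
try:
  np.negative(b)
except TypeError as e:
  print("numpy:", e)

# equality against False (what Tensor.logical_not now lowers to via CMPEQ) matches np.logical_not
print(np.equal(b, False), np.logical_not(b))
print(np.equal(f, False), np.logical_not(f))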


@@ -25,9 +25,8 @@ def as_strided(x, arg):
torch_fxn_for_op: Dict[Op, Callable] = {
BufferOps.CONST: lambda val, dtype: torch.tensor(val, device=device, dtype=inverse_type_map[dtype]),
- UnaryOps.EXP2: torch.exp2, UnaryOps.LOG2: torch.log2, UnaryOps.SIN: torch.sin, UnaryOps.SQRT: torch.sqrt,
+ UnaryOps.EXP2: torch.exp2, UnaryOps.LOG2: torch.log2, UnaryOps.SIN: torch.sin, UnaryOps.SQRT: torch.sqrt, UnaryOps.NEG: torch.neg,
UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(inverse_type_map[y[0]]),
- UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x),
BinaryOps.ADD: torch.add, BinaryOps.SUB: torch.sub, BinaryOps.MUL: torch.mul, BinaryOps.DIV: lambda x,y: torch.div(x, y).type(x.dtype),
BinaryOps.XOR: torch.bitwise_xor, BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: torch.lt, BinaryOps.CMPEQ: torch.eq,
ReduceOps.SUM: lambda x, new_shape: x.sum(reduce_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True) if x.shape != new_shape else x,


@@ -722,7 +722,7 @@ class Tensor:
# ***** mlops (unary) *****
def neg(self): return mlops.Neg.apply(self)
- def logical_not(self): return self.neg() if self.dtype == dtypes.bool else (1.0-self)
+ def logical_not(self): return mlops.Eq.apply(*self._broadcasted(False))
def contiguous(self): return mlops.Contiguous.apply(self)
def contiguous_backward(self): return mlops.ContiguousBackward.apply(self)
def log(self): return mlops.Log.apply(self.cast(least_upper_float(self.dtype)))
@@ -846,7 +846,7 @@ class Tensor:
# ***** op wrappers (wasted lines to make the typechecker happy) *****
- def __neg__(self) -> Tensor: return self.neg()
+ def __neg__(self) -> Tensor: return self.neg() if self.dtype != dtypes.bool else self.logical_not()
def __add__(self, x) -> Tensor: return self.add(x)
def __sub__(self, x) -> Tensor: return self.sub(x)
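
This last hunk is where the final commit bullet lands: the old definition fell back to (1.0 - self) for non-bool inputs, which both keeps a float dtype and doesn't match logical-not semantics for values other than 0 and 1, while an elementwise equality against a broadcast False always yields bool, so no implicit upcast can sneak in. A small hedged check of that intent, assuming the same checkout as above:

from tinygrad.tensor import Tensor

f = Tensor([1., 2., 0., 0.5])
print(f.logical_not().dtype)     # expected: dtypes.bool
print(f.logical_not().numpy())   # expected: [False False  True False]
print((1.0 - f).dtype)           # the old-style arithmetic form keeps a float dtype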