CMPEQ -> CMPNE and make it safe to pad (#4818)

* CMPNE

* new dataset
This commit is contained in:
chenyu
2024-06-03 18:02:15 -04:00
committed by GitHub
parent 79c7d402ee
commit 3afc914617
14 changed files with 31 additions and 31 deletions

View File

@@ -88,8 +88,8 @@ class Sigmoid(Function):
class Sign(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer:
return x.e(BinaryOps.CMPEQ, x.const(0)).e(TernaryOps.WHERE, x.const(0),
x.e(BinaryOps.CMPLT, x.const(0)).e(TernaryOps.WHERE, x.const(-1), x.const(1)))
return x.e(BinaryOps.CMPNE, x.const(0)).e(
TernaryOps.WHERE, x.e(BinaryOps.CMPLT, x.const(0)).e(TernaryOps.WHERE, x.const(-1), x.const(1)), x.const(0))
# backward always return 0 to match torch
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.const(0)
@@ -99,8 +99,8 @@ class Less(Function):
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPLT, y)
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
class Eq(Function):
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPEQ, y)
class Neq(Function):
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPNE, y)
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
class Xor(Function):
@@ -166,7 +166,7 @@ class Max(Function):
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
# 1s in locations where the max was chosen (can be two locations)
max_is_1s = self.x.e(BinaryOps.CMPEQ, self.ret.expand(self.x.shape)).cast(dtypes.float)
max_is_1s = self.x.const(1.0).cast(dtypes.float).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPNE, self.ret.expand(self.x.shape)).cast(dtypes.float))
div = max_is_1s.r(ReduceOps.SUM, self.axis).expand(self.x.shape)
return max_is_1s.e(BinaryOps.DIV, div).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output.expand(self.x.shape))