mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
@@ -88,8 +88,8 @@ class Sigmoid(Function):
|
||||
|
||||
class Sign(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
||||
return x.e(BinaryOps.CMPEQ, x.const(0)).e(TernaryOps.WHERE, x.const(0),
|
||||
x.e(BinaryOps.CMPLT, x.const(0)).e(TernaryOps.WHERE, x.const(-1), x.const(1)))
|
||||
return x.e(BinaryOps.CMPNE, x.const(0)).e(
|
||||
TernaryOps.WHERE, x.e(BinaryOps.CMPLT, x.const(0)).e(TernaryOps.WHERE, x.const(-1), x.const(1)), x.const(0))
|
||||
# backward always return 0 to match torch
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.const(0)
|
||||
|
||||
@@ -99,8 +99,8 @@ class Less(Function):
|
||||
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPLT, y)
|
||||
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
|
||||
|
||||
class Eq(Function):
|
||||
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPEQ, y)
|
||||
class Neq(Function):
|
||||
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPNE, y)
|
||||
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
|
||||
|
||||
class Xor(Function):
|
||||
@@ -166,7 +166,7 @@ class Max(Function):
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
# 1s in locations where the max was chosen (can be two locations)
|
||||
max_is_1s = self.x.e(BinaryOps.CMPEQ, self.ret.expand(self.x.shape)).cast(dtypes.float)
|
||||
max_is_1s = self.x.const(1.0).cast(dtypes.float).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPNE, self.ret.expand(self.x.shape)).cast(dtypes.float))
|
||||
div = max_is_1s.r(ReduceOps.SUM, self.axis).expand(self.x.shape)
|
||||
return max_is_1s.e(BinaryOps.DIV, div).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user