Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-23 13:58:00 -05:00)
hotfix: remove dead code, save lines
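The diff below applies one mechanical change across the Function subclasses it touches: any forward or backward whose body is a single return statement is folded onto its def line, and Sub.forward additionally drops the self.x_dtype/self.y_dtype assignment that the commit message flags as dead code. As a standalone illustration (not tinygrad code; the class names are invented for the example), the one-line and two-line spellings of such a method behave identically:

    # Minimal sketch: a def whose body is a single return can sit on the header line.
    # Both spellings compile to the same behavior, which is why this commit can
    # delete lines without changing semantics.
    class TwoLine:
      def backward(self, grad_output):
        return grad_output * 2

    class OneLine:
      def backward(self, grad_output): return grad_output * 2

    assert TwoLine().backward(3) == OneLine().backward(3) == 6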
@@ -20,8 +20,7 @@ class Cast(Function):
     self.input_dtype, self.bitcast = x.dtype, bitcast
     return x.cast(dtype, bitcast)
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.cast(self.input_dtype, self.bitcast)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.cast(self.input_dtype, self.bitcast)
 
 # ************* unary ops *************
 
@@ -55,16 +54,14 @@ class Log(Function):
     self.x = x
     return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.e(BinaryOps.DIV, self.x)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(BinaryOps.DIV, self.x)
 
 class Exp(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2)
     return self.ret
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return self.ret.e(BinaryOps.MUL, grad_output)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.e(BinaryOps.MUL, grad_output)
 
 class Sqrt(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
@@ -88,29 +85,23 @@ class Sigmoid(Function):
 # ************* binary ops *************
 
 class Less(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    return x.e(BinaryOps.CMPLT, y)
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPLT, y)
 
 class Eq(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    return x.e(BinaryOps.CMPEQ, y)
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPEQ, y)
 
 class Xor(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    return x.e(BinaryOps.XOR, y)
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.XOR, y)
 
 class Add(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    return x.e(BinaryOps.ADD, y)
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.ADD, y)
 
   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     return grad_output if self.needs_input_grad[0] else None, \
            grad_output if self.needs_input_grad[1] else None
 
 class Sub(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    self.x_dtype, self.y_dtype = x.dtype, y.dtype
-    return x.e(BinaryOps.SUB, y)
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.SUB, y)
 
   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     return grad_output if self.needs_input_grad[0] else None, \
@@ -153,8 +144,7 @@ class Sum(Function):
     self.input_shape = x.shape
     return x.r(ReduceOps.SUM, new_shape)
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.expand(self.input_shape)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)
 
 class Max(Function):
   def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
@@ -175,45 +165,39 @@ class Expand(Function):
     self.input_shape = x.shape
     return x.expand(shape)
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.r(ReduceOps.SUM, self.input_shape)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.r(ReduceOps.SUM, self.input_shape)
 
 class Reshape(Function):
   def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
     self.input_shape = x.shape
     return x.reshape(shape)
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.reshape(self.input_shape)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.reshape(self.input_shape)
 
 class Permute(Function):
   def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer:
     self.input_order = order
     return x.permute(order)
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.permute(argsort(self.input_order))
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.permute(argsort(self.input_order))
 
 class Pad(Function):
   def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
     self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)])
     return x.pad(arg)
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.shrink(self.narg)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.shrink(self.narg)
 
 class Shrink(Function):
   def forward(self, x:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
     self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
     return x.shrink(arg)
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.pad(self.narg)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.pad(self.narg)
 
 class Flip(Function):
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
     self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))])
     return x.stride(self.arg)
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.stride(self.arg)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.stride(self.arg)