"""This is where the forwards and backwards passes live."""
|
|
import math
|
|
from typing import Tuple, Optional
|
|
from tinygrad.helpers import argsort
|
|
from tinygrad.dtype import dtypes, DType, sum_acc_dtype
|
|
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
|
|
from tinygrad.tensor import Function
|
|
from tinygrad.lazy import LazyBuffer
|
|
from tinygrad.shape.symbolic import sint
|
|
|
|
class Contiguous(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output

# like Contiguous, but realizes the *gradient* contiguously instead of the forward value
class ContiguousBackward(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer: return x
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.contiguous()

class Cast(Function):
  def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer:
    self.input_dtype, self.bitcast = x.dtype, bitcast
    return x.cast(dtype, bitcast)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.cast(self.input_dtype, self.bitcast)

# ************* unary ops *************

class Neg(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer: return x.e(UnaryOps.NEG)
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(UnaryOps.NEG)

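# d/dx (1/x) = -1/x**2 = -(1/x)*(1/x), so backward can reuse the saved reciprocal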
class Reciprocal(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.ret = x.e(UnaryOps.RECIP)
    return self.ret
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.ret).e(BinaryOps.MUL, self.ret)

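# d/dx sin(x) = cos(x) = sin(pi/2 - x); rewritten with SIN since it is the only trig primitive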
class Sin(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.x = x
    return x.e(UnaryOps.SIN)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return self.x.const(math.pi / 2).e(BinaryOps.ADD, self.x.e(UnaryOps.NEG)).e(UnaryOps.SIN).e(BinaryOps.MUL, grad_output)

# NOTE: maximum(x, 0) behaves differently where x=0
class Relu(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.ret = x.e(BinaryOps.MAX, x.const(0))
    return self.ret

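  # the gradient is 1 where the input was positive and 0 elsewhere; (0 < ret) builds that mask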
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output)

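# ln(x) = log2(x) * ln(2), since only the base-2 log is a primitive; d/dx ln(x) = 1/x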
class Log(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.x = x
    return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(BinaryOps.MUL, self.x.e(UnaryOps.RECIP))

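# exp(x) = 2**(x/ln(2)), since only the base-2 exp is a primitive; exp is its own derivative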
class Exp(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2)
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.e(BinaryOps.MUL, grad_output)

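# d/dx sqrt(x) = 1/(2*sqrt(x)), so backward can reuse the saved result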
class Sqrt(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.ret = x.e(UnaryOps.SQRT)
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.e(BinaryOps.MUL, self.ret.e(BinaryOps.MUL, self.ret.const(2)).e(UnaryOps.RECIP))

# NOTE: the implicit derivative of sigmoid is not stable
# https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
# TODO: have the backend automatically find this
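# sigmoid(x) = 1/(1 + exp(-x)) is computed as 1/(1 + 2**(-x/ln(2))); backward uses the stable form s*(1-s)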
class Sigmoid(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.ret = x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)).e(UnaryOps.RECIP)
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return self.ret.e(BinaryOps.MUL, self.ret.const(1).e(BinaryOps.ADD, self.ret.e(UnaryOps.NEG))).e(BinaryOps.MUL, grad_output)

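# sign(x): 0 where x == 0, otherwise -1 for negative and 1 for positive, built from nested WHEREs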
class Sign(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    return x.e(BinaryOps.CMPNE, x.const(0)).e(
      TernaryOps.WHERE, x.e(BinaryOps.CMPLT, x.const(0)).e(TernaryOps.WHERE, x.const(-1), x.const(1)), x.const(0))
  # backward always returns 0 to match torch
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.const(0)

# ************* binary ops *************

# comparison ops are not differentiable, so backward returns no gradient for either input
class Less(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPLT, y)
  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None

class Neq(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPNE, y)
  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None

class Xor(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.XOR, y)

class BitwiseAnd(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.AND, y)

class BitwiseOr(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.OR, y)

class Threefry(Function):
  def forward(self, x:LazyBuffer, seed:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.THREEFRY, seed)

class Add(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.ADD, y)

  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    return grad_output if self.needs_input_grad[0] else None, \
           grad_output if self.needs_input_grad[1] else None

class Mul(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    self.x, self.y = x, y
    return x.e(BinaryOps.MUL, y)

  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    return self.y.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \
           self.x.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None

# integer division is not differentiable, so no backward is defined
class IDiv(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.IDIV, y)

# ************* ternary ops *************

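# grad flows to y where the condition is true and to z where it is false; the condition itself gets no grad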
class Where(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer:
    self.x = x
    return self.x.e(TernaryOps.WHERE, y, z)

  def backward(self, grad_output:LazyBuffer) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]:
    return None, \
      self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) if self.needs_input_grad[1] else None, \
      self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) if self.needs_input_grad[2] else None

# ************* reduce ops *************

class Sum(Function):
  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
    self.input_shape = x.shape
    return x.r(ReduceOps.SUM, axis)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)

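# NOTE: when several elements tie for the max, the gradient is split evenly among them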
class Max(Function):
  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
    self.x, self.ret, self.axis = x, x.r(ReduceOps.MAX, axis), axis
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    # 1s in locations where the max was chosen (can be two or more locations)
    # written as (x != max) != True instead of 1-(x != max): no NEG, so it can be instruction fused more easily
    max_is_1s = self.x.e(BinaryOps.CMPNE, self.ret.expand(self.x.shape)).e(BinaryOps.CMPNE, self.x.const(1).cast(dtypes.bool)).cast(grad_output.dtype)
    div = max_is_1s.r(ReduceOps.SUM, self.axis).expand(self.x.shape)
    return max_is_1s.e(BinaryOps.MUL, div.e(UnaryOps.RECIP)).e(BinaryOps.MUL, grad_output.expand(self.x.shape))

# ************* movement ops *************

# NOTE: this is sum in reverse
class Expand(Function):
  def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
    self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if si != so)
    return x.expand(shape)

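  # the sum runs in the wider accumulation dtype (sum_acc_dtype) and casts back, presumably to limit precision loss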
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(ReduceOps.SUM, self.expanded_axis).cast(grad_output.dtype)

class Reshape(Function):
  def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
    self.input_shape = x.shape
    return x.reshape(shape)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.reshape(self.input_shape)

class Permute(Function):
  def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer:
    self.input_order = order
    return x.permute(order)

  # argsort of a permutation is its inverse, undoing the forward permute
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.permute(argsort(self.input_order))

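# Pad and Shrink are inverses of each other: the backward of each applies the other,
# with narg precomputed in forward to address the complementary region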
class Pad(Function):
  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
    self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)])
    return x.pad(arg)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.shrink(self.narg)

class Shrink(Function):
  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
    self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
    return x.shrink(arg)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.pad(self.narg)

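# a flip is its own inverse, so backward reapplies the same negative strides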
class Flip(Function):
  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
    self.arg = tuple([-1 if i in axis else 1 for i in range(len(x.shape))])
    return x.stride(self.arg)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.stride(self.arg)