Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-14 17:38:06 -05:00
* fixed division by zero for fast operations
* made et closer to 0
* replace POW llop with SQRT
* updated mlops to swap SQRT and POW llops
* updated hlops to swap POW and SQRT
* added sqrt llop to cpu runtime
* added sqrt llop to cstyle codegen
* added POW llop to llvm ir codegen
* added SQRT llop to torch runtime
* moved pow from mlops to hlops
* found a better way to do reverse pow
* fixed indentation
* added SQRT llop to triton
* update docs to match new llops
* removed POW operator from assembly codegen
* added sqrt and rsqrt to pow hlop
* rewrote pow function in tensor.py
* Adjust tolerance
* Adjust for adamw
* Reduce for Adam too
* removed accidental leftover code
* removed all of accidental code
* added rsqrt test
* removed pow from mlops again it was added back when resolving merge conflicts

---------

Co-authored-by: Jacky Lee <jla524@sfu.ca>
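Context for the pow rework described above: once POW is no longer a llop, x**y can be built at the tensor (hlop) level from existing ops, with square roots special-cased so they lower to the new SQRT llop. A minimal sketch of the idea, assuming a Tensor API with exp, log, sqrt, and rsqrt; the helper name is hypothetical and this is not the exact tensor.py implementation:

from tinygrad.tensor import Tensor

def pow_sketch(x: Tensor, y: float) -> Tensor:
  # fast paths: lower to the SQRT llop instead of a POW llop
  if y == 0.5: return x.sqrt()
  if y == -0.5: return x.rsqrt()
  # general case (valid for x > 0): x**y == exp(y * log(x))
  return (x.log() * y).exp()

The real hlop also has to handle negative bases and reverse pow (y ** x), which this sketch omits.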
220 lines
8.6 KiB
Python
from typing import Tuple, Optional
from tinygrad.helpers import argsort, ShapeType
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps
from tinygrad.tensor import Function
from tinygrad.lazy import LazyBuffer
import math
class Contiguous(Function):
  def forward(self, x): return x.contiguous()
  def backward(self, grad_output): return grad_output
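# Cast saves the input dtype so backward can cast the gradient back to it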
class Cast(Function):
  __slots__ = "input_dtype"
  def forward(self, x, dtype):
    self.input_dtype = x.dtype
    return x.cast(dtype)
  def backward(self, grad_output):
    return grad_output.cast(self.input_dtype)
# ************* unary ops *************
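# d/dx sin(x) = cos(x), computed here as sin(pi/2 - x) using the SIN llop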
class Sin(Function):
  __slots__ = "x"
  def forward(self, x: LazyBuffer) -> LazyBuffer:
    self.x = x
    return x.unary_op(UnaryOps.SIN)
  def backward(self, grad: LazyBuffer) -> LazyBuffer:
    return self.x.const_like(math.pi / 2).binary_op(BinaryOps.SUB, self.x).unary_op(UnaryOps.SIN).binary_op(BinaryOps.MUL, grad)
# NOTE: maximum(x, 0) behaves differently where x=0
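# grad flows where ret > 0: mask = 1 - (ret == 0)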
class Relu(Function):
  __slots__ = "ret"
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.ret = x.binary_op(BinaryOps.MAX, x.const_like(0))
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    mask = self.ret.const_like(1).binary_op(BinaryOps.SUB, self.ret.binary_op(BinaryOps.CMPEQ, self.ret.const_like(0)))
    return mask.binary_op(BinaryOps.MUL, grad_output)
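# ln(x) = log2(x) * ln(2); d/dx ln(x) = 1/x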
class Log(Function):
  __slots__ = "x"
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.x = x
    return x.unary_op(UnaryOps.LOG2).binary_op(BinaryOps.MUL, x.const_like(math.log(2)))

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.binary_op(BinaryOps.DIV, self.x)
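# e^x = 2^(x / ln 2); d/dx e^x = e^x, so backward reuses the saved result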
class Exp(Function):
  __slots__ = "ret"
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.ret = x.binary_op(BinaryOps.MUL, x.const_like(1/math.log(2))).unary_op(UnaryOps.EXP2)
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return self.ret.binary_op(BinaryOps.MUL, grad_output)
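# d/dx sqrt(x) = 1 / (2 * sqrt(x)); backward reuses the saved result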
class Sqrt(Function):
  __slots__ = "ret"
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.ret = x.unary_op(UnaryOps.SQRT)
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.binary_op(BinaryOps.DIV, self.ret.binary_op(BinaryOps.MUL, self.ret.const_like(2)))
# NOTE: the implicit derivative of sigmoid is not stable
# https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
# TODO: have the backend automatically find this
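# sigmoid(x) = 1 / (1 + e^-x), built from EXP2; d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))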
class Sigmoid(Function):
  __slots__ = "ret"
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.ret = x.const_like(1).binary_op(BinaryOps.DIV, x.const_like(1).binary_op(BinaryOps.ADD, x.binary_op(BinaryOps.MUL, x.const_like(-1/math.log(2))).unary_op(UnaryOps.EXP2)))
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return self.ret.binary_op(BinaryOps.MUL, self.ret.const_like(1).binary_op(BinaryOps.SUB, self.ret)).binary_op(BinaryOps.MUL, grad_output)
# ************* reduce ops *************
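# backward of a sum broadcasts the gradient back to the input shape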
class Sum(Function):
  __slots__ = "input_shape"
  def forward(self, x:LazyBuffer, new_shape:ShapeType) -> LazyBuffer:
    self.input_shape = x.shape
    return x.reduce_op(ReduceOps.SUM, new_shape)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.expand(self.input_shape)
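# ties split the gradient: each max location receives grad / (number of maxima)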
class Max(Function):
  __slots__ = "x", "ret"
  def forward(self, x:LazyBuffer, new_shape:ShapeType) -> LazyBuffer:
    self.x, self.ret = x, x.reduce_op(ReduceOps.MAX, new_shape)
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    # 1s in locations where the max was chosen (can be more than one location)
    max_is_1s = self.x.binary_op(BinaryOps.CMPEQ, self.ret.expand(self.x.shape))

    # sum of locations, averaged
    div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
    max_is_amount = max_is_1s.binary_op(BinaryOps.DIV, div)

    grad_output_expanded = grad_output.expand(self.x.shape)
    return max_is_amount.binary_op(BinaryOps.MUL, grad_output_expanded)
# ************* binary ops *************
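# no backward: comparison results are not differentiable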
class Equal(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    return x.binary_op(BinaryOps.CMPEQ, y)
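# gradient goes to whichever input was selected; where x == y the splitter halves it between both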
class Maximum(Function):
  __slots__ = "x", "y", "ret"
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    self.x, self.y = x, y
    self.ret = x.binary_op(BinaryOps.MAX, y)
    return self.ret

  def backward(self, grad_output:LazyBuffer):
    mask = self.y.binary_op(BinaryOps.CMPEQ, self.ret)
    eq = self.x.binary_op(BinaryOps.CMPEQ, self.y)
    splitter = eq.const_like(2).binary_op(BinaryOps.SUB, eq).binary_op(BinaryOps.DIV, eq.const_like(2))

    return grad_output.binary_op(BinaryOps.MUL, mask.const_like(1).binary_op(BinaryOps.SUB, mask).binary_op(BinaryOps.ADD, eq)).binary_op(BinaryOps.MUL, splitter) if self.needs_input_grad[0] else None, \
           grad_output.binary_op(BinaryOps.MUL, mask).binary_op(BinaryOps.MUL, splitter) if self.needs_input_grad[1] else None
class Add(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    return x.binary_op(BinaryOps.ADD, y)

  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    return grad_output if self.needs_input_grad[0] else None, \
           grad_output if self.needs_input_grad[1] else None
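# d/dy (x - y) = -1; the negation is expressed as 0 - grad_output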
class Sub(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    return x.binary_op(BinaryOps.SUB, y)

  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    return grad_output if self.needs_input_grad[0] else None, \
           grad_output.const_like(0).binary_op(BinaryOps.SUB, grad_output) if self.needs_input_grad[1] else None
class Mul(Function):
  __slots__ = 'x', 'y'
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    self.x, self.y = x, y
    return x.binary_op(BinaryOps.MUL, y)

  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    return self.y.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \
           self.x.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None
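# d/dx (x / y) = 1 / y; d/dy (x / y) = -x / y**2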
class Div(Function):
  __slots__ = 'x', 'y'
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    self.x, self.y = x, y
    return x.binary_op(BinaryOps.DIV, y)

  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    return grad_output.binary_op(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \
           grad_output.const_like(0).binary_op(BinaryOps.SUB, grad_output).binary_op(BinaryOps.MUL, self.x).binary_op(BinaryOps.DIV, self.y.binary_op(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None
# ************* movement ops *************
# NOTE: this is sum in reverse
class Expand(Function):
  __slots__ = 'input_shape'
  def forward(self, x:LazyBuffer, shape:ShapeType) -> LazyBuffer:
    self.input_shape = x.shape
    return x.expand(shape)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.reduce_op(ReduceOps.SUM, self.input_shape)
class Reshape(Function):
  __slots__ = 'input_shape'
  def forward(self, x:LazyBuffer, shape:ShapeType) -> LazyBuffer:
    self.input_shape = x.shape
    return x.reshape(shape)

  def backward(self, grad_output:LazyBuffer):
    return grad_output.reshape(self.input_shape)
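# argsort of a permutation yields its inverse, which undoes the permute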
class Permute(Function):
  __slots__ = 'input_order'
  def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer:
    self.input_order = order
    return x.permute(order)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.permute(argsort(self.input_order))
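# narg records where the input sits inside the padded output, so backward can shrink the gradient back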
class Pad(Function):
  __slots__ = 'narg'
  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
    self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)])
    return x.pad(arg)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.shrink(self.narg)
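# backward zero-pads the gradient to restore the regions that were cut away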
class Shrink(Function):
  __slots__ = 'narg'
  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
    self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
    return x.shrink(arg)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.pad(self.narg)
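# a stride of -1 flips an axis; applying the same strides again is its own inverse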
class Flip(Function):
  __slots__ = 'arg'
  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]):
    self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))])
    return x.stride(self.arg)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.stride(self.arg)