mirror of https://github.com/tinygrad/tinygrad.git
Remove POW llop and add SQRT llop (#1104)
* fixed division by zero for fast operations
* made et closer to 0
* replace POW llop with SQRT
* updated mlops to swap SQRT and POW llops
* updated hlops to swap POW and SQRT
* added sqrt llop to cpu runtime
* added sqrt llop to cstyle codegen
* added POW llop to llvm ir codegen
* added SQRT llop to torch runtime
* moved pow from mlops to hlops
* found a better way to do reverse pow
* fixed indentation
* added SQRT llop to triton
* update docs to match new llops
* removed POW operator from assembly codegen
* added sqrt and rsqrt to pow hlop
* rewrote pow function in tensor.py
* Adjust tolerance
* Adjust for adamw
* Reduce for Adam too
* removed accidental leftover code
* removed all of accidental code
* added rsqrt test
* removed pow from mlops again, it was added back when resolving merge conflicts

---------

Co-authored-by: Jacky Lee <jla524@sfu.ca>
@@ -143,13 +143,6 @@ class AssemblyCodegen(Linearizer):
       pred_reg = newreg((newvar, 'pred'), dtype=dtypes.bool)
       ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [tor[x] for x in vin], args))
       ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
-    elif args == BinaryOps.POW:
-      # TODO: add UnaryOps.SQRT
-      tmp = newreg((newvar, "exp_a"))
-      tmp2 = newreg((newvar, "exp_a_times_b"))
-      ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[0]]], UnaryOps.LOG2))
-      ins.append(AssemblyInstruction(UOps.ALU, tmp2, [tmp, tor[vin[1]]], BinaryOps.MUL))
-      ins.append(AssemblyInstruction(UOps.ALU, out, [tmp2], UnaryOps.EXP2))
     elif args == BinaryOps.DIV and self.no_div:
       tmp = newreg((newvar, "rcp"))
       ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[1]]], UnaryOps.RECIP))
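Note on the removed branch: it lowered BinaryOps.POW through the identity a**b = exp2(b * log2(a)), i.e. a LOG2 → MUL → EXP2 instruction sequence. A standalone NumPy sketch of that identity (illustration only, not tinygrad code):

```python
import numpy as np

def pow_via_exp2_log2(a, b):
  # a**b == 2**(b * log2(a)) for a > 0, which is what the removed
  # LOG2 -> MUL -> EXP2 instruction sequence computed.
  return np.exp2(np.log2(a) * b)

a, b = np.array([2.0, 9.0]), np.array([10.0, 0.5])
assert np.allclose(pow_via_exp2_log2(a, b), a ** b)  # [1024., 3.]
```

The identity only holds for a > 0, which is one reason a POW llop is awkward at this level and the general case gets pushed up into the hlops.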
@@ -50,9 +50,10 @@ code_for_op: Final[Dict[Op, Callable]] = {
   UnaryOps.EXP2: lambda x: f"exp2({x})",
   UnaryOps.LOG2: lambda x: f"log2({x})",
   UnaryOps.SIN: lambda x: f"sin({x})",
+  UnaryOps.SQRT: lambda x: f"sqrt({x})",
   BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
   BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
-  BinaryOps.POW: lambda a,b: f"pow({a},{b})", BinaryOps.MAX: lambda a,b: f"max({a},{b})",
+  BinaryOps.MAX: lambda a,b: f"max({a},{b})",
   BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})"
 }
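Each table entry renders an op to a C-style source fragment from already-rendered operand strings, so entries compose; the new UnaryOps.SQRT entry just emits a sqrt() call. A toy sketch of the pattern (mock dictionary keyed by strings, not the real codegen table):

```python
# Mock of the render-table pattern above: op -> function returning C source text.
code_for_op = {
  "SQRT": lambda x: f"sqrt({x})",
  "ADD": lambda a, b: f"({a}+{b})",
}

# Entries compose because each one returns a plain string:
print(code_for_op["SQRT"](code_for_op["ADD"]("val0", "val1")))  # sqrt((val0+val1))
```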
@@ -22,11 +22,11 @@ code_for_op: Final[Dict[Op, Callable]] = {
   UnaryOps.EXP2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [ir.FloatType()]), [x], fastmath=('fast',)),
   UnaryOps.LOG2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [ir.FloatType()]), [x], fastmath=('fast',)),
   UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [ir.FloatType()]), [x], fastmath=('fast',)),
+  UnaryOps.SQRT: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sqrt', [ir.FloatType()]), [x], fastmath=('fast',)),
   BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y, flags=('fast',)),
   BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)),
   BinaryOps.MUL: lambda builder,x,y: builder.fmul(x,y, flags=('fast',)),
   BinaryOps.DIV: lambda builder,x,y: builder.fdiv(x,y, flags=('fast',)),
-  BinaryOps.POW: lambda builder,x,y: builder.call(builder._block.module.declare_intrinsic('llvm.pow', [ir.FloatType()]), [x,y], fastmath=('fast',)),
   BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y, flags=('fast',)), ir.FloatType()),
   BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=('fast',)), x, y, flags=('fast',)),
   FusedOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(x,y, flags=('fast',)), z, flags=('fast',)),
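On the LLVM backend the new entry calls the llvm.sqrt intrinsic through llvmlite. A minimal self-contained sketch of that call pattern, assuming llvmlite is installed (the module/function scaffolding here is illustrative, not tinygrad's):

```python
from llvmlite import ir

# Build a tiny function: float my_sqrt(float x) { return llvm.sqrt(x); }
module = ir.Module(name="sqrt_demo")
fnty = ir.FunctionType(ir.FloatType(), [ir.FloatType()])
func = ir.Function(module, fnty, name="my_sqrt")
builder = ir.IRBuilder(func.append_basic_block(name="entry"))

# declare_intrinsic('llvm.sqrt', ...) is the same call the SQRT entry above makes.
sqrt_intrin = module.declare_intrinsic('llvm.sqrt', [ir.FloatType()])
builder.ret(builder.call(sqrt_intrin, [func.args[0]]))

print(module)  # textual LLVM IR containing a call to @llvm.sqrt.f32
```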
@@ -55,6 +55,15 @@ class Exp(Function):
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     return self.ret.binary_op(BinaryOps.MUL, grad_output)

+class Sqrt(Function):
+  __slots__ = "ret"
+  def forward(self, x:LazyBuffer) -> LazyBuffer:
+    self.ret = x.unary_op(UnaryOps.SQRT)
+    return self.ret
+
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+    return grad_output.binary_op(BinaryOps.DIV, self.ret.binary_op(BinaryOps.MUL, self.ret.const_like(2)))
+
 # NOTE: the implicit derivative of sigmoid is not stable
 # https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
 # TODO: have the backend automatically find this
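The new Sqrt.backward uses d/dx sqrt(x) = 1/(2*sqrt(x)), i.e. grad_output divided by (ret * 2). A quick standalone NumPy check of that rule (not part of the change):

```python
import numpy as np

x = np.array([0.25, 1.0, 4.0])
ret = np.sqrt(x)
analytic = 1.0 / (2.0 * ret)            # the rule used by Sqrt.backward

eps = 1e-6
numeric = (np.sqrt(x + eps) - np.sqrt(x - eps)) / (2 * eps)  # central difference
assert np.allclose(analytic, numeric, atol=1e-5)
```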
@@ -142,16 +151,6 @@ class Mul(Function):
     return self.y.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \
            self.x.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None

-class Pow(Function):
-  __slots__ = 'x', 'y', 'ret'
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    self.x, self.y, self.ret = x, y, x.binary_op(BinaryOps.POW, y)
-    return self.ret
-
-  def backward(self, grad_output:LazyBuffer):
-    return grad_output.binary_op(BinaryOps.MUL, self.y.binary_op(BinaryOps.MUL, self.ret.binary_op(BinaryOps.DIV, self.x))) if self.needs_input_grad[0] else None, \
-           grad_output.binary_op(BinaryOps.MUL, self.x.unary_op(UnaryOps.LOG2).binary_op(BinaryOps.MUL, self.x.const_like(math.log(2))).binary_op(BinaryOps.MUL, self.ret)) if self.needs_input_grad[1] else None
-
 class Div(Function):
   __slots__ = 'x', 'y'
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
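The deleted Pow.backward encoded d/dx x**y = y*x**y/x and d/dy x**y = ln(x)*x**y, the latter written as log2(x)*ln(2)*x**y since only LOG2 exists as a llop. A standalone numeric sanity check of both identities (assumes x > 0):

```python
import math

x, y, eps = 3.0, 2.5, 1e-6
ret = x ** y
dx_analytic = y * ret / x                       # d/dx x**y
dy_analytic = math.log2(x) * math.log(2) * ret  # d/dy x**y == ln(x) * x**y

dx_numeric = ((x + eps) ** y - (x - eps) ** y) / (2 * eps)
dy_numeric = (x ** (y + eps) - x ** (y - eps)) / (2 * eps)
assert math.isclose(dx_analytic, dx_numeric, rel_tol=1e-4)
assert math.isclose(dy_analytic, dy_numeric, rel_tol=1e-4)
```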
@@ -217,4 +216,4 @@ class Flip(Function):
     return x.stride(self.arg)

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.stride(self.arg)
+    return grad_output.stride(self.arg)
@@ -12,8 +12,8 @@ if TYPE_CHECKING:
 # the Enum class doesn't work with mypy, this is static. sorry it's ugly
 # NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
 # NOTE: rdna3 only has RECIP and not DIV. DIV and POW are on the chopping block
-class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); RECIP = auto() # noqa: E702
-class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
+class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto() # noqa: E702
+class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); CMPEQ = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class FusedOps(Enum): MULACC = auto() # noqa: E702
 class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
@@ -10,7 +10,7 @@ def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple
   return tuple(i for i,(a,b) in enumerate(zip(old_shape, new_shape)) if a != b)

 base_fxn_for_op: Dict[Op, Callable] = {
-  BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow,
+  BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv,
   ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
   ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
   MovementOps.RESHAPE: lambda x, arg: x.reshape(arg), MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)],
@@ -32,7 +32,7 @@ def match_types(x, y):

 numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
   UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.CAST: lambda x,y: x.astype(y.np), UnaryOps.SIN: np.sin,
-  BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.promote_types(x.dtype,y.dtype)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)),
+  BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.promote_types(x.dtype,y.dtype)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)), UnaryOps.SQRT: np.sqrt,
   MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
   MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, i) for i in arg)],
   FusedOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy()), lambda x: x.strides, np.broadcast_to),
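The backend tables are built by dict-merging the shared base_fxn_for_op with backend-specific entries, and the new SQRT entry maps straight onto np.sqrt. A tiny sketch of the merge-and-dispatch pattern (mock string keys for brevity):

```python
import numpy as np

base = {"ADD": np.add, "MUL": np.multiply}
numpy_table = {**base, **{"SQRT": np.sqrt}}   # later dict wins on key clashes

x = np.array([1.0, 4.0, 9.0])
print(numpy_table["SQRT"](x))  # [1. 2. 3.]
```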
@@ -10,7 +10,7 @@ type_map = {torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.
 inverse_type_map = {v:k for k,v in type_map.items()}

 torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
-  UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.CAST: lambda x,y: x.type(next(k for k,v in type_map.items() if v==y)), UnaryOps.SIN: torch.sin,
+  UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.SQRT: lambda x: x.sqrt(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.CAST: lambda x,y: x.type(next(k for k,v in type_map.items() if v==y)), UnaryOps.SIN: torch.sin,
   BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).type(torch.promote_types(x.dtype, y.dtype)),
   MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
   FusedOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(torch.promote_types(a.dtype, b.dtype)), lambda x: x.stride(), lambda x,s: x.expand(s)),
@@ -487,6 +487,8 @@ class Tensor:
   def relu(self): return mlops.Relu.apply(self)
   def sigmoid(self): return mlops.Sigmoid.apply(self)
   def sin(self): return mlops.Sin.apply(self)
+  def sqrt(self): return mlops.Sqrt.apply(self)
+  def rsqrt(self): return (1/self).sqrt()
   def cos(self): return ((pi/2)-self).sin()
   def tan(self): return self.sin() / self.cos()
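sqrt now dispatches to the new mlops.Sqrt, and rsqrt is derived from it as (1/x).sqrt(). A usage sketch, assuming a tinygrad checkout containing this change (Tensor construction and .numpy() as in the repo at the time):

```python
from tinygrad.tensor import Tensor
import numpy as np

x = Tensor([0.25, 1.0, 4.0])
print(x.sqrt().numpy())   # [0.5 1.  2. ]
print(x.rsqrt().numpy())  # [2.  1.  0.5]
assert np.allclose(x.rsqrt().numpy(), 1.0 / np.sqrt(np.array([0.25, 1.0, 4.0])))
```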
@@ -504,8 +506,6 @@ class Tensor:
     return (self < b).where(b-1, b)

   def __neg__(self): return 0.0-self
-  def sqrt(self): return self.pow(0.5)
-  def rsqrt(self): return self.pow(-0.5)
   def square(self): return self*self
   def clip(self, min_, max_): return self.maximum(min_).minimum(max_)
   def abs(self): return self.relu() + (-self).relu()
@@ -552,13 +552,15 @@ class Tensor:
   def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Add, x, reverse) if x.__class__ is Tensor or x else self
   def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Sub, x, reverse) if x.__class__ is Tensor or x or reverse else self
   def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Mul, x, reverse) if x.__class__ is Tensor or x != 1.0 else self
+  def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if x.__class__ is Tensor or reverse or not x else self.mul(1/x)
   def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor:
     if x.__class__ is not Tensor and not reverse:
       # simple pow identities
       if x < 0: return (1.0/self).pow(-x)
       if x == 2.0: return self*self
       if x == -1.0: return 1/self
-    return self._broadcasted(mlops.Pow, x, reverse) if x.__class__ is Tensor or x != 1.0 or reverse else self
-  def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if x.__class__ is Tensor or reverse or not x else self.mul(1/x)
+      if x == 1.0: return self
+      if x == 0.5: return self.sqrt()
+    return self.log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(log(x)).exp()
   def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)

   def maximum(self, x:Union[Tensor, float]) -> Tensor: return self._broadcasted(mlops.Maximum, x)
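The rewritten pow falls back to exp(b * ln(a)) once the simple identities (negative, 2.0, -1.0, 1.0, 0.5) don't apply, and the reverse (scalar ** Tensor) case becomes exp(a * ln(b)) via mul(log(x)).exp(). A NumPy check of both fallbacks (valid for positive bases):

```python
import numpy as np

a = np.array([0.5, 2.0, 7.0])
b = 1.7
assert np.allclose(np.exp(np.log(a) * b), a ** b)  # forward fallback: a ** b
assert np.allclose(np.exp(a * np.log(b)), b ** a)  # reverse fallback: b ** a
```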