EXP+LOG -> EXP2+LOG2 (#954)

* EXP+LOG -> EXP2+LOG2

* update docs
Author: George Hotz (committed by GitHub)
Date: 2023-06-08 10:57:31 -07:00
Parent: 666d151f8a
Commit: df40a9c238
8 changed files with 18 additions and 20 deletions

View File

@@ -2,7 +2,7 @@ from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple,
import math, collections
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer, LocalTypes
from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
-from tinygrad.helpers import getenv, partition, ImageDType, DEBUG, dtypes, colored, prod
+from tinygrad.helpers import partition, ImageDType, DEBUG, dtypes, colored, prod
from tinygrad.runtime.lib import RawConst
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, Node, SumNode, MulNode
from tinygrad.lazy import LazyBuffer
@@ -12,8 +12,6 @@ render_cl = render_python.copy()
render_cl[DivNode] = lambda self,ops,ctx: f"({self.a.render(ops, ctx)}/{self.b})"
render_cl[AndNode] = lambda self,ops,ctx: f"({'&&'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
-NATIVE_EXPLOG = getenv("NATIVE_EXPLOG", 0) # this is needed as a switch for the tests to pass
class CStyleLanguage(NamedTuple):
kernel_prefix: str = ""
buffer_prefix: str = ""
@@ -48,8 +46,8 @@ def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=F
return idx, idy
code_for_op: Final[Dict[Op, Callable]] = {
UnaryOps.EXP: lambda x: f"native_exp({x})" if NATIVE_EXPLOG else f"exp({x})",
UnaryOps.LOG: lambda x: f"native_log({x})" if NATIVE_EXPLOG else f"log({x})",
UnaryOps.EXP2: lambda x: f"exp2({x})",
UnaryOps.LOG2: lambda x: f"log2({x})",
UnaryOps.SIN: lambda x: f"sin({x})",
BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
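
Note (illustrative, not part of the diff): with EXP2/LOG2 as the primitive ops, the C-style renderer can emit the standard exp2/log2 math builtins directly, so the NATIVE_EXPLOG switch between native_exp/exp goes away. A trimmed, standalone sketch of what these rendering lambdas produce:

```python
# Illustrative sketch only (trimmed copy, not the tinygrad codegen module):
# each unary op maps straight to a C/OpenCL math builtin rendered as a string.
from enum import Enum, auto

class UnaryOps(Enum): EXP2 = auto(); LOG2 = auto(); SIN = auto()  # noqa: E702

code_for_op = {
  UnaryOps.EXP2: lambda x: f"exp2({x})",
  UnaryOps.LOG2: lambda x: f"log2({x})",
  UnaryOps.SIN:  lambda x: f"sin({x})",
}
print(code_for_op[UnaryOps.EXP2]("val0"))  # -> exp2(val0)
print(code_for_op[UnaryOps.LOG2]("val1"))  # -> log2(val1)
```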

View File

@@ -20,8 +20,8 @@ render_llvm = {
}
code_for_op: Final[Dict[Op, Callable]] = {
-UnaryOps.EXP: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp', [ir.FloatType()]), [x], fastmath=('fast',)),
-UnaryOps.LOG: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log', [ir.FloatType()]), [x], fastmath=('fast',)),
+UnaryOps.EXP2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [ir.FloatType()]), [x], fastmath=('fast',)),
+UnaryOps.LOG2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [ir.FloatType()]), [x], fastmath=('fast',)),
UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [ir.FloatType()]), [x], fastmath=('fast',)),
BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y, flags=('fast',)),
BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)),
@@ -88,11 +88,11 @@ def uops_to_llvm_ir(uops:List[UOp], bufs:List[LazyBuffer]) -> str:
val = bb[-1].select(valid, bb[-1].load(bb[-1].gep(func.args[args.i], [aug_idx], inbounds=True)), ir.Constant(func_dtypes[args[0]], 0))
else:
val = bb[-1].load(bb[-1].gep(func.args[args.i], [idx], inbounds=True))
-if func_dtypes[args.i] != ir.FloatType():
+if func_dtypes[args.i] != ir.FloatType():
if dtypes.is_int(bufs[args.i].dtype):
val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(bufs[args.i].dtype) else bb[-1].sitofp(val, ir.FloatType())
else:
-val = bb[-1].fpext(val, ir.FloatType())
+val = bb[-1].fpext(val, ir.FloatType())
lvars[newvar] = val
if uop == UOps.STORE:
assert args.valid.min == 1, "store must be valid"
@@ -101,7 +101,7 @@ def uops_to_llvm_ir(uops:List[UOp], bufs:List[LazyBuffer]) -> str:
if func_dtypes[0] != ir.FloatType():
if dtypes.is_int(bufs[args.i].dtype):
element = bb[-1].fptoui(element, func_dtypes[0]) if dtypes.is_unsigned(bufs[args.i].dtype) else bb[-1].fptosi(element, func_dtypes[0])
-else:
+else:
element = bb[-1].fptrunc(element, func_dtypes[0])
bb[-1].store(element, bb[-1].gep(func.args[args.i], [idx], inbounds=True))
if uop == UOps.ALU:
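
Note (illustrative, not part of the diff): EXP2/LOG2 map onto the llvm.exp2/llvm.log2 intrinsics exactly as EXP/LOG mapped onto llvm.exp/llvm.log. A minimal llvmlite sketch of that pattern, with made-up module/function names, assuming an llvmlite version that supports fastmath on calls:

```python
# Standalone llvmlite sketch (names like "exp2_demo" are invented): declare the
# llvm.exp2 intrinsic on a module and emit a fast-math call to it, mirroring
# the code_for_op entry in the hunk above.
import llvmlite.ir as ir

module = ir.Module(name="exp2_demo")
func = ir.Function(module, ir.FunctionType(ir.FloatType(), [ir.FloatType()]), name="exp2_wrapper")
builder = ir.IRBuilder(func.append_basic_block(name="entry"))
exp2 = module.declare_intrinsic('llvm.exp2', [ir.FloatType()])
builder.ret(builder.call(exp2, [func.args[0]], fastmath=('fast',)))
print(module)  # printed IR includes a declaration of llvm.exp2.f32
```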

View File

@@ -37,14 +37,14 @@ class Relu(Function):
class Log(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer:
self.x = x
-return x.unary_op(UnaryOps.LOG)
+return x.unary_op(UnaryOps.LOG2).binary_op(BinaryOps.MUL, x.const_like(math.log(2)/math.log(math.e)))
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.binary_op(BinaryOps.DIV, self.x)
class Exp(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer:
-self.ret = x.unary_op(UnaryOps.EXP)
+self.ret = x.binary_op(BinaryOps.MUL, x.const_like(math.log(math.e)/math.log(2))).unary_op(UnaryOps.EXP2)
return self.ret
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
@@ -128,7 +128,7 @@ class Pow(Function):
def backward(self, grad_output:LazyBuffer):
return grad_output.binary_op(BinaryOps.MUL, self.y.binary_op(BinaryOps.MUL, self.ret.binary_op(BinaryOps.DIV, self.x))) if self.needs_input_grad[0] else None, \
-grad_output.binary_op(BinaryOps.MUL, self.x.unary_op(UnaryOps.LOG).binary_op(BinaryOps.MUL, self.ret)) if self.needs_input_grad[1] else None
+grad_output.binary_op(BinaryOps.MUL, self.x.unary_op(UnaryOps.LOG2).binary_op(BinaryOps.MUL, self.x.const_like(math.log(2)/math.log(math.e))).binary_op(BinaryOps.MUL, self.ret)) if self.needs_input_grad[1] else None
class Div(Function):
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
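
Note (illustrative, not part of the diff): the mlops rewrites above use the base-change identities ln(x) = log2(x)·ln(2) and e^x = 2^(x·log2(e)); the constants math.log(2)/math.log(math.e) and math.log(math.e)/math.log(2) are just ln(2) and log2(e). A quick standalone check with plain math:

```python
# Sanity check of the base-change identities used in Log/Exp above
# (plain-Python sketch, no tinygrad): exp/log rebuilt from exp2/log2.
import math

LN2 = math.log(2) / math.log(math.e)      # ln(2), the constant in Log.forward
LOG2E = math.log(math.e) / math.log(2)    # log2(e), the constant in Exp.forward

for x in (0.5, 1.0, 2.0, 3.7):
  assert math.isclose(math.log2(x) * LN2, math.log(x), rel_tol=1e-12)
  assert math.isclose(2.0 ** (x * LOG2E), math.exp(x), rel_tol=1e-12)
print("exp2/log2 reproduce exp/log")
```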

View File

@@ -8,7 +8,7 @@ from tinygrad.runtime.lib import RawBuffer, RawConst
# these are the llops your accelerator must implement, along with toCpu
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
-class UnaryOps(Enum): NOOP = auto(); EXP = auto(); LOG = auto(); CAST = auto(); SIN = auto() # noqa: E702
+class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto() # noqa: E702
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto() # noqa: E702
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class FusedOps(Enum): MULACC = auto() # noqa: E702

View File

@@ -27,7 +27,7 @@ def einsum_mulacc(einsum, get_strides, expand):
return mulacc
numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
-UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP: np.exp, UnaryOps.LOG: np.log, UnaryOps.CAST: lambda x,y: x.astype(y.np), UnaryOps.SIN: np.sin,
+UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.CAST: lambda x,y: x.astype(y.np), UnaryOps.SIN: np.sin,
BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.float32),
MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, i) for i in arg)],
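
Note (illustrative, not part of the diff): with the CPU backend now dispatching EXP2/LOG2 to np.exp2/np.log2, results match the old exp/log path only because of the constant rescaling in mlops. A quick numpy check of that composition:

```python
# Numerical sanity check (sketch, not part of the diff): np.exp2/np.log2 plus
# the mlops scaling constants reproduce np.exp/np.log in float32.
import numpy as np

x = np.linspace(0.1, 5.0, 7, dtype=np.float32)
assert np.allclose(np.exp2(x * np.float32(np.log2(np.e))), np.exp(x), rtol=1e-5)
assert np.allclose(np.log2(x) * np.float32(np.log(2)), np.log(x), rtol=1e-5)
```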

View File

@@ -10,7 +10,7 @@ type_map = {torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.
inverse_type_map = {v:k for k,v in type_map.items()}
torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
-UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), UnaryOps.CAST: lambda x,y: x.type(next(k for k,v in type_map.items() if v==y)), UnaryOps.SIN: torch.sin,
+UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.CAST: lambda x,y: x.type(next(k for k,v in type_map.items() if v==y)), UnaryOps.SIN: torch.sin,
BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).float(),
MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
FusedOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(torch.promote_types(a.dtype, b.dtype)), lambda x: x.stride(), lambda x,s: x.expand(s)),
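
Note (illustrative, not part of the diff): the torch backend mirrors the numpy check above via Tensor.exp2/Tensor.log2 (assumes a torch version that provides exp2, i.e. >= 1.8):

```python
# Same sanity check for the torch backend (sketch): exp2/log2 plus the mlops
# constants reproduce exp/log within the default allclose tolerance.
import math, torch

x = torch.linspace(0.1, 5.0, 7)
assert torch.allclose((x * math.log2(math.e)).exp2(), x.exp())
assert torch.allclose(x.log2() * math.log(2), x.log())
```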