@@ -2,7 +2,7 @@ from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple,
 import math, collections
 from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer, LocalTypes
 from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
-from tinygrad.helpers import getenv, partition, ImageDType, DEBUG, dtypes, colored, prod
+from tinygrad.helpers import partition, ImageDType, DEBUG, dtypes, colored, prod
 from tinygrad.runtime.lib import RawConst
 from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, Node, SumNode, MulNode
 from tinygrad.lazy import LazyBuffer
@@ -12,8 +12,6 @@ render_cl = render_python.copy()
 render_cl[DivNode] = lambda self,ops,ctx: f"({self.a.render(ops, ctx)}/{self.b})"
 render_cl[AndNode] = lambda self,ops,ctx: f"({'&&'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
 
-NATIVE_EXPLOG = getenv("NATIVE_EXPLOG", 0) # this is needed as a switch for the tests to pass
-
 class CStyleLanguage(NamedTuple):
   kernel_prefix: str = ""
   buffer_prefix: str = ""
@@ -48,8 +46,8 @@ def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=F
   return idx, idy
 
 code_for_op: Final[Dict[Op, Callable]] = {
-  UnaryOps.EXP: lambda x: f"native_exp({x})" if NATIVE_EXPLOG else f"exp({x})",
-  UnaryOps.LOG: lambda x: f"native_log({x})" if NATIVE_EXPLOG else f"log({x})",
+  UnaryOps.EXP2: lambda x: f"exp2({x})",
+  UnaryOps.LOG2: lambda x: f"log2({x})",
   UnaryOps.SIN: lambda x: f"sin({x})",
   BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
   BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
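Note: with NATIVE_EXPLOG gone, the C-style backend only ever emits exp2/log2; natural exp/log are recovered one level up (in mlops, below) by a change of base. A quick plain-Python sketch of the identities this relies on, not part of the commit:

```python
import math

def exp_via_exp2(x: float) -> float:
  # exp(x) = 2**(x * log2(e)), and log2(e) == 1/ln(2)
  return 2.0 ** (x * (1.0 / math.log(2)))

def log_via_log2(x: float) -> float:
  # ln(x) = log2(x) * ln(2)
  return math.log2(x) * math.log(2)

assert abs(exp_via_exp2(1.5) - math.exp(1.5)) < 1e-12
assert abs(log_via_log2(7.0) - math.log(7.0)) < 1e-12
```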
@@ -20,8 +20,8 @@ render_llvm = {
 }
 
 code_for_op: Final[Dict[Op, Callable]] = {
-  UnaryOps.EXP: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp', [ir.FloatType()]), [x], fastmath=('fast',)),
-  UnaryOps.LOG: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log', [ir.FloatType()]), [x], fastmath=('fast',)),
+  UnaryOps.EXP2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [ir.FloatType()]), [x], fastmath=('fast',)),
+  UnaryOps.LOG2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [ir.FloatType()]), [x], fastmath=('fast',)),
   UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [ir.FloatType()]), [x], fastmath=('fast',)),
   BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y, flags=('fast',)),
   BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)),
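Note: declare_intrinsic plus builder.call is the stock llvmlite route to LLVM's math intrinsics. A self-contained sketch of the same pattern, assuming llvmlite is installed (module and function names here are illustrative):

```python
import llvmlite.ir as ir

module = ir.Module(name="demo")
func = ir.Function(module, ir.FunctionType(ir.FloatType(), [ir.FloatType()]), name="exp2_demo")
builder = ir.IRBuilder(func.append_basic_block(name="entry"))
# declare (or reuse) the float32 overload of llvm.exp2, then call it fast-math
exp2 = module.declare_intrinsic('llvm.exp2', [ir.FloatType()])
builder.ret(builder.call(exp2, [func.args[0]], fastmath=('fast',)))
print(module)  # textual LLVM IR containing the llvm.exp2.f32 call
```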
@@ -88,11 +88,11 @@ def uops_to_llvm_ir(uops:List[UOp], bufs:List[LazyBuffer]) -> str:
         val = bb[-1].select(valid, bb[-1].load(bb[-1].gep(func.args[args.i], [aug_idx], inbounds=True)), ir.Constant(func_dtypes[args[0]], 0))
       else:
         val = bb[-1].load(bb[-1].gep(func.args[args.i], [idx], inbounds=True))
       if func_dtypes[args.i] != ir.FloatType():
         if dtypes.is_int(bufs[args.i].dtype):
           val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(bufs[args.i].dtype) else bb[-1].sitofp(val, ir.FloatType())
         else:
           val = bb[-1].fpext(val, ir.FloatType())
       lvars[newvar] = val
     if uop == UOps.STORE:
       assert args.valid.min == 1, "store must be valid"
@@ -101,7 +101,7 @@ def uops_to_llvm_ir(uops:List[UOp], bufs:List[LazyBuffer]) -> str:
       if func_dtypes[0] != ir.FloatType():
         if dtypes.is_int(bufs[args.i].dtype):
           element = bb[-1].fptoui(element, func_dtypes[0]) if dtypes.is_unsigned(bufs[args.i].dtype) else bb[-1].fptosi(element, func_dtypes[0])
         else:
           element = bb[-1].fptrunc(element, func_dtypes[0])
       bb[-1].store(element, bb[-1].gep(func.args[args.i], [idx], inbounds=True))
     if uop == UOps.ALU:
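Note: the LOAD/STORE paths pick the conversion by signedness (uitofp/sitofp in, fptoui/fptosi out) because the same bit pattern decodes to different values signed vs unsigned. A numpy illustration of the distinction, not from this commit:

```python
import numpy as np

raw = np.array([0xFF], dtype=np.uint8)
print(raw.astype(np.float32))                # [255.] -- unsigned reading, like uitofp
print(raw.view(np.int8).astype(np.float32))  # [-1.]  -- signed reading, like sitofp
```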
@@ -37,14 +37,14 @@ class Relu(Function):
 class Log(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
-    return x.unary_op(UnaryOps.LOG)
+    return x.unary_op(UnaryOps.LOG2).binary_op(BinaryOps.MUL, x.const_like(math.log(2)/math.log(math.e)))
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     return grad_output.binary_op(BinaryOps.DIV, self.x)
 
 class Exp(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.unary_op(UnaryOps.EXP)
+    self.ret = x.binary_op(BinaryOps.MUL, x.const_like(math.log(math.e)/math.log(2))).unary_op(UnaryOps.EXP2)
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
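Note: Log and Exp now lower to LOG2/EXP2 plus one constant multiply, and the constants in the new forwards are exactly the change-of-base factors. A quick numerical check (plain Python, illustrative):

```python
import math

x = 0.7
# Log.forward: log2(x) * (ln 2 / ln e) == ln(x)
assert abs(math.log2(x) * (math.log(2) / math.log(math.e)) - math.log(x)) < 1e-12
# Exp.forward: 2 ** (x * (ln e / ln 2)) == exp(x)
assert abs(2.0 ** (x * (math.log(math.e) / math.log(2))) - math.exp(x)) < 1e-12
```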
@@ -128,7 +128,7 @@ class Pow(Function):
 
   def backward(self, grad_output:LazyBuffer):
     return grad_output.binary_op(BinaryOps.MUL, self.y.binary_op(BinaryOps.MUL, self.ret.binary_op(BinaryOps.DIV, self.x))) if self.needs_input_grad[0] else None, \
-           grad_output.binary_op(BinaryOps.MUL, self.x.unary_op(UnaryOps.LOG).binary_op(BinaryOps.MUL, self.ret)) if self.needs_input_grad[1] else None
+           grad_output.binary_op(BinaryOps.MUL, self.x.unary_op(UnaryOps.LOG2).binary_op(BinaryOps.MUL, self.x.const_like(math.log(2)/math.log(math.e))).binary_op(BinaryOps.MUL, self.ret)) if self.needs_input_grad[1] else None
 
 class Div(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
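Note: the second return value implements d(x**y)/dy = x**y * ln(x), with ln(x) rebuilt as log2(x) * ln(2) now that LOG is gone. A finite-difference sanity check of that formula (values are illustrative):

```python
import math

x, y, eps = 2.0, 3.0, 1e-6
analytic = (x ** y) * math.log2(x) * (math.log(2) / math.log(math.e))  # x**y * ln(x)
numeric = ((x ** (y + eps)) - (x ** (y - eps))) / (2 * eps)
assert abs(analytic - numeric) < 1e-4
```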
@@ -8,7 +8,7 @@ from tinygrad.runtime.lib import RawBuffer, RawConst
 
 # these are the llops your accelerator must implement, along with toCpu
 # the Enum class doesn't work with mypy, this is static. sorry it's ugly
-class UnaryOps(Enum): NOOP = auto(); EXP = auto(); LOG = auto(); CAST = auto(); SIN = auto() # noqa: E702
+class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto() # noqa: E702
 class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto() # noqa: E702
 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class FusedOps(Enum): MULACC = auto() # noqa: E702
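Note: these enums are the whole backend contract, so renaming EXP/LOG to EXP2/LOG2 means every runtime's dispatch table must supply the new kernels. A toy table in the same style as numpy_fxn_for_op below (names are illustrative, not tinygrad API):

```python
import math
from enum import Enum, auto

class ToyUnaryOps(Enum): EXP2 = auto(); LOG2 = auto()  # noqa: E702

toy_fxn_for_op = {ToyUnaryOps.EXP2: lambda x: 2.0 ** x, ToyUnaryOps.LOG2: math.log2}
print(toy_fxn_for_op[ToyUnaryOps.EXP2](3.0))  # 8.0
```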
@@ -27,7 +27,7 @@ def einsum_mulacc(einsum, get_strides, expand):
   return mulacc
 
 numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
-  UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP: np.exp, UnaryOps.LOG: np.log, UnaryOps.CAST: lambda x,y: x.astype(y.np), UnaryOps.SIN: np.sin,
+  UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.CAST: lambda x,y: x.astype(y.np), UnaryOps.SIN: np.sin,
   BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.float32),
   MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
   MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, i) for i in arg)],
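Note: a sanity check (assumes numpy) that the swap to np.exp2/np.log2 matches the old np.exp/np.log up to float32 rounding once the change-of-base constants are applied:

```python
import numpy as np

a = np.linspace(0.1, 4.0, 5, dtype=np.float32)
np.testing.assert_allclose(np.exp(a), np.exp2(a * np.float32(1 / np.log(2))), rtol=1e-5)
np.testing.assert_allclose(np.log(a), np.log2(a) * np.float32(np.log(2)), rtol=1e-5)
```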
@@ -10,7 +10,7 @@ type_map = {torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.
 inverse_type_map = {v:k for k,v in type_map.items()}
 
 torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
-  UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), UnaryOps.CAST: lambda x,y: x.type(next(k for k,v in type_map.items() if v==y)), UnaryOps.SIN: torch.sin,
+  UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.CAST: lambda x,y: x.type(next(k for k,v in type_map.items() if v==y)), UnaryOps.SIN: torch.sin,
   BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).float(),
   MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
   FusedOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(torch.promote_types(a.dtype, b.dtype)), lambda x: x.stride(), lambda x,s: x.expand(s)),
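Note: the same check for the torch backend (assumes a torch build with Tensor.exp2/log2):

```python
import torch

t = torch.linspace(0.1, 4.0, 5)
assert torch.allclose(t.exp(), (t * (1 / torch.log(torch.tensor(2.0)))).exp2())
assert torch.allclose(t.log(), t.log2() * torch.log(torch.tensor(2.0)))
```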