mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
@@ -37,7 +37,7 @@ def uops_to_llvm_ir(uops:List[UOp], bufs:List[LazyBuffer]) -> str:
|
||||
module = ir.Module(name=__file__)
|
||||
|
||||
# create llvm function
|
||||
func_dtypes = [{dtypes.float16:ir.HalfType(), dtypes.float32:ir.FloatType(), dtypes.float64:ir.DoubleType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64)}[buf.dtype] for buf in bufs]
|
||||
func_dtypes = [{dtypes.float16:ir.HalfType(), dtypes.float32:ir.FloatType(), dtypes.float64:ir.DoubleType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}[buf.dtype] for buf in bufs]
|
||||
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name='exec')
|
||||
|
||||
# force llvmlite to allow us to add function attribute then add the attribute
|
||||
|
||||
@@ -26,9 +26,13 @@ def einsum_mulacc(einsum, get_strides, expand):
|
||||
return expand(ret.reshape([(1 if i not in a_axes and i not in b_axes else s) for i,s in enumerate(new_shape)]), new_shape)
|
||||
return mulacc
|
||||
|
||||
def match_types(x, y):
|
||||
up = x.dtype if dtypes.from_np(x.dtype).priority > dtypes.from_np(y.dtype).priority else y.dtype
|
||||
return x.astype(up), y.astype(up)
|
||||
|
||||
numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
|
||||
UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.CAST: lambda x,y: x.astype(y.np), UnaryOps.SIN: np.sin,
|
||||
BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.float32),
|
||||
BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.promote_types(x.dtype,y.dtype)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)),
|
||||
MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
|
||||
MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, i) for i in arg)],
|
||||
FusedOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy()), lambda x: x.strides, np.broadcast_to),
|
||||
|
||||
@@ -494,6 +494,10 @@ class Tensor:
|
||||
def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype).where(Tensor.zeros_like(self), self)
|
||||
|
||||
# ***** math functions (unary) *****
|
||||
def ceil(self: Tensor) -> Tensor:
|
||||
b = self.cast(dtypes.int32).contiguous()
|
||||
return (self > 0).where(b+1, b)
|
||||
def floor(self: Tensor) -> Tensor: return self.ceil() - 1
|
||||
|
||||
def __neg__(self): return 0.0-self
|
||||
def sqrt(self): return self.pow(0.5)
|
||||
|
||||
Reference in New Issue
Block a user