Adds floor/ceil (#989)

* floor ceil impl

* control casting in numpy
This commit is contained in:
Diogo
2023-06-17 13:56:21 -04:00
committed by GitHub
parent 775287ed91
commit d2b837c1d9
6 changed files with 15 additions and 4 deletions

View File

@@ -37,7 +37,7 @@ def uops_to_llvm_ir(uops:List[UOp], bufs:List[LazyBuffer]) -> str:
module = ir.Module(name=__file__)
# create llvm function
func_dtypes = [{dtypes.float16:ir.HalfType(), dtypes.float32:ir.FloatType(), dtypes.float64:ir.DoubleType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64)}[buf.dtype] for buf in bufs]
func_dtypes = [{dtypes.float16:ir.HalfType(), dtypes.float32:ir.FloatType(), dtypes.float64:ir.DoubleType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}[buf.dtype] for buf in bufs]
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name='exec')
# force llvmlite to allow us to add function attribute then add the attribute

View File

@@ -26,9 +26,13 @@ def einsum_mulacc(einsum, get_strides, expand):
return expand(ret.reshape([(1 if i not in a_axes and i not in b_axes else s) for i,s in enumerate(new_shape)]), new_shape)
return mulacc
def match_types(x, y):
up = x.dtype if dtypes.from_np(x.dtype).priority > dtypes.from_np(y.dtype).priority else y.dtype
return x.astype(up), y.astype(up)
numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.CAST: lambda x,y: x.astype(y.np), UnaryOps.SIN: np.sin,
BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.float32),
BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.promote_types(x.dtype,y.dtype)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)),
MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, i) for i in arg)],
FusedOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy()), lambda x: x.strides, np.broadcast_to),

View File

@@ -494,6 +494,10 @@ class Tensor:
def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype).where(Tensor.zeros_like(self), self)
# ***** math functions (unary) *****
def ceil(self: Tensor) -> Tensor:
b = self.cast(dtypes.int32).contiguous()
return (self > 0).where(b+1, b)
def floor(self: Tensor) -> Tensor: return self.ceil() - 1
def __neg__(self): return 0.0-self
def sqrt(self): return self.pow(0.5)