diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 07b1b9418e..323a43b490 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -34,47 +34,6 @@ Op = Union[UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps]
 # do not preserve f(0) = 0
 UNSAFE_PAD_OPS = {UnaryOps.RECIP, UnaryOps.LOG2, UnaryOps.EXP2, BinaryOps.IDIV}
 
-@dataclass(frozen=True)
-class KernelInfo:
-  local_dims: int = 0 # number of local dimensions (this is remapping RANGE to SPECIAL)
-  upcasted: int = 0 # count that are upcasted (this is remapping RANGE to EXPAND)
-  dont_use_locals: bool = False # don't use local indexing
-
-# **************** ops in python ****************
-
-def hook_overflow(dv, fxn):
-  def wfxn(*args):
-    try: return fxn(*args)
-    except OverflowError: return dv
-  return wfxn
-
-python_alu: Dict[Op, Callable] = {
-  UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, UnaryOps.EXP2: hook_overflow(math.inf, lambda x: 2**x),
-  UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
-  UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
-  BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift, BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add,
-  BinaryOps.XOR: operator.xor, BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt,
-  BinaryOps.OR: operator.or_, BinaryOps.AND: operator.and_,
-  BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
-  TernaryOps.MULACC: lambda x,y,z: (x*y)+z, TernaryOps.WHERE: lambda x,y,z: y if x else z}
-
-def truncate_fp16(x):
-  try:
-    x = float(x)
-    struct.pack("@e", x)
-    return x
-  except OverflowError: return math.copysign(math.inf, x)
-
-truncate: Dict[DType, Callable] = {dtypes.bool: bool,
-  # TODO: bfloat16
-  dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
-  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
-  dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
-  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value \
-    if isinstance(x,int) else x, dtypes.int64: lambda x: ctypes.c_int64(x).value}
-
-def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
-
 # the order of these UOps controls the order of the toposort
 class UOps(Enum):
   # ops that aren't rendered
@@ -226,18 +185,26 @@ class UOp:
         (UOp.sconst(dtypes.bool, False), UOp.sconst(dtypes.bool, False)) if s0.vmin.arg >= s1.vmax.arg else (None, None)
     return None, None
 
+@dataclass(frozen=True)
+class KernelInfo:
+  local_dims: int = 0 # number of local dimensions (this is remapping RANGE to SPECIAL)
+  upcasted: int = 0 # count that are upcasted (this is remapping RANGE to EXPAND)
+  dont_use_locals: bool = False # don't use local indexing
+
+# ***** pattern matcher *****
+
 @dataclass(frozen=True, repr=False) # reuse repr from UOp
 class NOp(UOp):
-  name:Optional[str] = None
-  src:Tuple[NOp, ...] = tuple()
-  allow_any_len:bool = False
+  name: Optional[str] = None
+  src: Tuple[NOp, ...] = tuple()
+  allow_any_len: bool = False
   @staticmethod
   def var(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.NOOP, dtype=dtype, name=name)
   @staticmethod
   def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.CONST, dtype=dtype, name=name)
   def const(self:Union[UOp, DType, None], b:ConstType|Variable): return NOp((x:=UOp.const(self, b)).op, x.dtype, x.src, x.arg)
 
-  def compile(self: NOp, name:Optional[str]=None) -> UPat:
+  def compile(self:NOp, name:Optional[str]=None) -> UPat:
     return UPat(name=self.name, dtype=self.dtype) if self.op is UOps.NOOP else UPat(self.op, self.arg,
       (list if self.commutative() else tuple)(src.compile() for src in self.src) or None, self.name or name, self.dtype, self.allow_any_len)
 
@@ -309,6 +276,48 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
     return found
   return __inner_rewrite(sink)
 
+# ***** ops in python *****
+
+def hook_overflow(dv, fxn):
+  def wfxn(*args):
+    try: return fxn(*args)
+    except OverflowError: return dv
+  return wfxn
+
+python_alu: Dict[Op, Callable] = {
+  UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, UnaryOps.EXP2: hook_overflow(math.inf, lambda x: 2**x),
+  UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
+  UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
+  BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift, BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add,
+  BinaryOps.XOR: operator.xor, BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt,
+  BinaryOps.OR: operator.or_, BinaryOps.AND: operator.and_,
+  BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
+  TernaryOps.MULACC: lambda x,y,z: (x*y)+z, TernaryOps.WHERE: lambda x,y,z: y if x else z}
+
+def truncate_fp16(x):
+  # round-trip through the IEEE half format ("e") so the result is actually rounded to fp16 precision,
+  # like the ctypes-based float32/float64 entries of `truncate`; a pack-only range check returns x unrounded.
+  # struct.pack raises OverflowError when the magnitude does not fit in fp16 -> saturate to signed infinity.
+  try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
+  except OverflowError: return math.copysign(math.inf, x)
+
+truncate: Dict[DType, Callable] = {dtypes.bool: bool,
+  # TODO: bfloat16
+  dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
+  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
+  dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
+  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value \
+    if isinstance(x,int) else x, dtypes.int64: lambda x: ctypes.c_int64(x).value}
+
+def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
+
+def uop_alu_resolve(u:UOp) -> sint:
+  if u.op in {UOps.CONST, UOps.DEFINE_VAR}: return u.arg
+  if u.op is UOps.ALU: return exec_alu(u.arg, cast(DType,u.dtype), tuple(map(uop_alu_resolve, u.src)))
+  raise RuntimeError(f"ALU resolve fail @ {u.op}")
+
+# ***** uop type spec *****
+
 def type_verify(uops):
   for u in uops:
     uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
@@ -351,10 +360,7 @@ def type_verify(uops):
         assert src[0].dtype == bd, f"{arg} selector dtype mismatch {src[0].dtype=} != {bd}"
         assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
 
-def uop_alu_resolve(u:UOp) -> sint:
-  if u.op in {UOps.CONST, UOps.DEFINE_VAR}: return u.arg
-  if u.op is UOps.ALU: return exec_alu(u.arg, cast(DType,u.dtype), tuple(map(uop_alu_resolve, u.src)))
-  raise RuntimeError(f"ALU resolve fail @ {u.op}")
+# ***** uop helpers *****
 
 def print_uops(uops:List[UOp]):
   for i,u in enumerate(uops):