diff --git a/extra/ops.py b/extra/ops.py
deleted file mode 100644
index 76367a12c4..0000000000
--- a/extra/ops.py
+++ /dev/null
@@ -1,135 +0,0 @@
-from __future__ import annotations
-from typing import Dict, Union, Tuple, Any, List, cast
-import functools, hashlib
-from enum import Enum, auto
-from dataclasses import dataclass
-from tinygrad.helpers import dedup, prod
-from tinygrad.ops import ReduceOps, UnaryOps, BinaryOps, TernaryOps, UOp, UOps, pretty_print
-from tinygrad.dtype import ImageDType, PtrDType, dtypes, DType, ConstType
-from tinygrad.ops import Variable, sint
-from tinygrad.shape.shapetracker import ShapeTracker
-
-# these ops are deleted after AST is UOp
-class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
-class MetaOps(Enum): KERNEL = auto();
-Op = Union[UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps, BufferOps]
-
-@dataclass(frozen=True)
-class MemBuffer:
-  idx: int
-  dtype: DType
-  st: ShapeTracker
-
-@dataclass(frozen=True)
-class ConstBuffer:
-  val: ConstType | Variable
-  dtype: DType
-  st: ShapeTracker
-
-@dataclass(frozen=True, eq=False)
-class LazyOp:
-  op: Op
-  src: Tuple[LazyOp, ...] = ()
-  arg: Any = None
-  def cached_compare(self, x, context):
-    if id(self) == id(x): return True
-    if self.op != x.op or self.arg != x.arg or len(self.src) != len(x.src): return False
-    if (key := (id(self), id(x))) in context: return context[key]
-    ret = context[key] = all(a.cached_compare(b, context) for a,b in zip(self.src, x.src))
-    return ret
-  def __eq__(self, x): return self.cached_compare(x, context={})
-  def __repr__(self:LazyOp): return pretty_print(self, lambda x: f'LazyOp({x.op}, arg={x.arg}, src=(%s))')
-  @functools.cached_property
-  def dtype(self) -> DType:
-    if self.op in BufferOps: return self.arg.dtype
-    if self.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return self.arg
-    return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.CMPNE} else self.src[-1].dtype
-  @functools.cached_property
-  def full_shape(self) -> Tuple[sint, ...]:
-    if len(self.src) == 0 and self.op in BufferOps: return self.arg.st.shape
-    return tuple(max(x) for x in zip(*[x.full_shape for x in self.src]))
-  @functools.cached_property
-  def key(self) -> bytes:
-    return hashlib.sha256(functools.reduce(lambda x,y: x+y, [s.key for s in self.src], str((self.op, self.arg)).encode())).digest()
-  @functools.cached_property
-  def hash(self): return hash((self.op, self.src, self.arg))
-  def __hash__(self): return self.hash
-  @functools.cached_property
-  def lazyops(self) -> List[LazyOp]: return dedup([self] + [item for x in self.src for item in x.lazyops])
-  def vars(self) -> List[Variable]:
-    extract_vars = [x.arg.st.vars() for x in self.lazyops if x.op in BufferOps]
-    const_vars = [x.arg.val for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)]
-    return sorted(set.union(*extract_vars, set(const_vars)), key=lambda v: v.expr)
-  def __add__(self, x:LazyOp): return LazyOp(BinaryOps.ADD, (self, x))
-  def __sub__(self, x:LazyOp): return LazyOp(BinaryOps.ADD, (self, -x))
-  def __mul__(self, x:LazyOp): return LazyOp(BinaryOps.MUL, (self, x))
-  def ne(self, x:LazyOp): return LazyOp(BinaryOps.CMPNE, (self, x))
-  def eq(self, x:LazyOp): return -self.ne(x)
-  def __neg__(self): return LazyOp(UnaryOps.NEG, (self,))
-  @staticmethod
-  def const(val, dtype:DType, shape:Tuple[sint, ...]):
-    return LazyOp(BufferOps.CONST, (), ConstBuffer(val, dtype, ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape)))
-
-# the living definition of LazyOps
-def verify_lazyop(ast:LazyOp) -> Dict[LazyOp, ShapeTracker]:
-  assert ast.op is MetaOps.KERNEL, "must be SINK"
-  sts: Dict[LazyOp, ShapeTracker] = {}
-  def assert_valid(op:LazyOp, st:ShapeTracker):
-    if op in sts: return
-    # restore globals from the two stage reduce
-    if op.op is BufferOps.LOAD and op.arg.idx < 0:
-      assert_valid(local_reduce:=op.src[0].src[0], op.arg.st)
-      return sts.setdefault(op, sts[local_reduce])
-    for x in op.src: assert_valid(x, st)
-    # only reduceop is allowed to change shape, limited to turning n to 1
-    if op.op in ReduceOps:
-      axis = op.arg
-      assert isinstance(axis, tuple) and all(isinstance(i, int) for i in axis), f"reduceop must have axis {op.arg}"
-      st = ShapeTracker.from_shape(sts[op.src[0]].reduce(axis))
-    else:
-      # movementops are pushed to the edges with LOAD
-      # elementwise inherits shape
-      st = op.arg.st if op.op in BufferOps else sts[op.src[0]]
-    for x in op.src:
-      if sts[x].shape != st.shape:
-        if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op.op} {sts[x].shape} != {st.shape}")
-        raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op.op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}")
-    sts[op] = st
-  for i, out in enumerate(ast.src):
-    assert out.arg.idx == i, f"unexpected output buffer idx {out.arg.idx} != {i}"
-    assert out.op is BufferOps.STORE, f"kernels must have stores as the output, got {out.op}"
-    assert out.arg.st.size == ast.src[-1].arg.st.size, f"outputs must have the same size, got {out.arg.st.size}"
-    assert_valid(out, out.arg.st)
-  shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
-  assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
-  return sts
-
-def to_uop(*a) -> UOp:
-  assert isinstance(a[0], LazyOp), f"{a} must be a LazyOp ast"
-  if a[0].op is BufferOps.STORE: ast = LazyOp(MetaOps.KERNEL, a)
-  else:
-    assert a[0].op is MetaOps.KERNEL
-    ast = a[0]
-  verify_lazyop(ast)
-  @functools.lru_cache(None)
-  def create_uop(lop:LazyOp) -> UOp:
-    if lop.op in BufferOps:
-      st_uop = lop.arg.st.to_uop()
-      membuf_dtype: DType = lop.arg.dtype
-      dtype = membuf_dtype.base if isinstance(membuf_dtype, ImageDType) else membuf_dtype
-      if lop.op is BufferOps.CONST:
-        return UOp(UOps.CONST, dtype, (st_uop,), lop.arg.val)
-      buf = UOp(UOps.DEFINE_GLOBAL, membuf_dtype if isinstance(membuf_dtype, ImageDType) else PtrDType(membuf_dtype), (), lop.arg.idx)
-      if lop.op is BufferOps.LOAD: return UOp(UOps.LOAD, dtype, (buf, st_uop))
-      return UOp(UOps.STORE, dtypes.void, (buf, st_uop, create_uop(lop.src[0])))
-    src = tuple(create_uop(x) for x in lop.src)
-    if lop.op is MetaOps.KERNEL: return UOp(UOps.SINK, dtypes.void, src)
-    if lop.op in ReduceOps:
-      alu_op = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.PROD:BinaryOps.MUL, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, lop.op)]
-      return UOp(UOps.REDUCE_AXIS, src[0].dtype, src, (alu_op, lop.arg))
-    if lop.op is UnaryOps.CAST: return UOp(UOps.CAST, lop.arg.scalar(), src)
-    if lop.op is UnaryOps.BITCAST: return UOp(UOps.BITCAST, lop.arg.scalar(), src)
-    return src[0].alu(lop.op, *src[1:])
-  ret = create_uop(ast)
-  #with open("/tmp/ast", "w") as f: f.write(str(ret))
-  return ret
diff --git a/extra/optimization/helpers.py b/extra/optimization/helpers.py
index 7e87a5c60d..a102a9a8e2 100644
--- a/extra/optimization/helpers.py
+++ b/extra/optimization/helpers.py
@@ -1,9 +1,8 @@
 # stuff needed to unpack a kernel
 from typing import Tuple
-from extra.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer, MetaOps
 from tinygrad import Variable
 from tinygrad.codegen.kernel import Opt, OptOps
-from tinygrad.ops import UOp, UOps, KernelInfo
+from tinygrad.ops import UOp, UOps, KernelInfo, TernaryOps, BinaryOps, UnaryOps, ReduceOps, MetaOps
 from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View