import functools from tinygrad.dtype import AddrSpace from tinygrad.mixin import MathMixin from tinygrad.uop.ops import UOp, Ops from extra.thunder.tiny.tk import WARP_THREADS def unwrap(x): if hasattr(x, "_uop"): return x._uop if isinstance(x, (list, tuple)): return type(x)(unwrap(y) for y in x) if isinstance(x, dict): return {k: unwrap(v) for k,v in x.items()} return x def wrap(x, ker, cls): if isinstance(x, UOp): return cls(x, ker) if isinstance(x, (list, tuple)): return type(x)(wrap(y, ker, cls) for y in x) return x def autowrap(source_cls, blacklist=None): if blacklist is None: blacklist = { "__init__", "__new__", "__str__", "__del__", "__repr__", "__dict__", "__getattribute__", "__setattr__", "__delattr__", "__weakref__", "__slots__", "__class__", "__reduce__", "__reduce_ex__", "__getstate__", "__setstate__", "__hash__" } def decorator(cls): def __getattr__(self, name): uop = object.__getattribute__(self, "_uop") val = getattr(uop, name) if callable(val): @functools.wraps(val) def proxy(*args, **kwargs): return wrap(val(*unwrap(args), **unwrap(kwargs)), self.ker, cls) return proxy if name in UOp.__slots__: return val return wrap(val, self.ker, cls) cls.__getattr__ = __getattr__ for name in dir(source_cls): if name in blacklist or not name.startswith("__"): continue for base in cls.mro(): if base is source_cls: break if name in base.__dict__: break else: original = getattr(source_cls, name) if callable(original): def make_proxy(op_name, func): def proxy(self, *args, **kwargs): return wrap(func(self._uop, *unwrap(args), **unwrap(kwargs)), self.ker, cls) return proxy setattr(cls, name, make_proxy(name, original)) return cls return decorator class TileMathMixin(MathMixin): def alu(self, op, *src, inner_op=lambda x:x): assert isinstance(self, (RT, RV)) if len(src) == 0: if self._uop._shape is None: uop = UOp.alu(self._uop, op) else: uop = self.ker.warp.map(self._uop, lambda x: UOp.alu(x, op)) elif len(src) == 1: if self._uop._shape is None: uop = UOp.alu(self._uop, op, inner_op(self._uop.ufix(src[0]))) elif isinstance(src[0], (int,float,bool)): uop = self.ker.warp.map(self._uop, lambda x: UOp.alu(x, op, inner_op(x.ufix(src[0])))) elif src[0]._shape is None: uop = UOp.alu(self._uop, op, inner_op(self._uop.ufix(src[0]))) else: if isinstance(self, RT) and isinstance(src[0], RV): uop = self.ker.warp.map(self._uop, lambda x, idx: UOp.alu(x, op, inner_op(src[0]._uop[idx[0], 0, (idx[2]%4)//2]))) else: uop = self.ker.warp.map(self._uop, lambda x, idx: UOp.alu(x, op, inner_op(src[0]._uop[*idx]))) else: raise NotImplementedError return type(self)(uop, self.ker) def const_like(self, b): return b # override ops that do compute on the src uop def sub(self, x, reverse=False): return self.ufix(x).alu(Ops.ADD, self, inner_op=lambda y: -y) if reverse else self.alu(Ops.ADD, self.ufix(x), inner_op=lambda y: -y) def div(self, x, reverse=False): return self.ufix(x).alu(Ops.MUL, self, inner_op=lambda y: 1/y) if reverse else self.alu(Ops.MUL, self.ufix(x), inner_op=lambda y: 1/y) @autowrap(UOp) class GL: def __init__(self, uop, ker): self._uop, self.ker = uop, ker @classmethod def create(cls, shape, dtype, ker): uop = ker.alloc(shape, dtype, AddrSpace.GLOBAL) return cls(uop, ker) @autowrap(UOp) class ST: def __init__(self, uop, ker): self._uop, self.ker = uop, ker @classmethod def create(cls, shape, dtype, ker): uop = ker.alloc(shape, dtype, AddrSpace.LOCAL) return cls(uop, ker) @autowrap(UOp) class RT(TileMathMixin): BASE_TILE_ROWS, BASE_TILE_COLS = 16, 16 BASE_TILE_NE = BASE_TILE_ROWS * BASE_TILE_COLS BASE_TILE_NEPT = BASE_TILE_NE // WARP_THREADS def __init__(self, uop, ker): self._uop, self.ker = uop, ker @classmethod def create(cls, shape, dtype, ker): assert len(shape) == 2 assert shape[0] % RT.BASE_TILE_ROWS == 0 assert shape[1] % RT.BASE_TILE_COLS == 0 height = shape[0] // RT.BASE_TILE_ROWS width = shape[1] // RT.BASE_TILE_COLS uop = ker.alloc((height, width, RT.BASE_TILE_NEPT), dtype, AddrSpace.REG) return cls(uop, ker) @autowrap(UOp) class RV(TileMathMixin): def __init__(self, uop, ker): self._uop, self.ker = uop, ker @classmethod def create(cls, length, dtype, layout, ker): tiles = length // RT.BASE_TILE_ROWS match layout: case "naive": inner_dim = 1 outer_dim = (tiles + 1) // 2 case "ortho": inner_dim = 1 outer_dim = tiles case _: raise NotImplementedError(f"rv layout {layout} not implemented") uop = ker.alloc((outer_dim, inner_dim, 2), dtype, AddrSpace.REG) return RV(uop, ker) ALL_TILES = UOp | GL | ST | RT | RV