diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index f0d1d2be28..2f8605c978 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -21,6 +21,16 @@ class OptOps(Enum):
 
 class KernelOptError(Exception): pass
 
+@dataclass(frozen=True)
+class KernelInfo:
+  full_shape: Tuple[sint, ...]  # full shape (findable in AST)
+  global_dims: int              # number of global dimensions (this is remapping RANGE to SPECIAL)
+  first_reduce: int             # axis of first_reduce (findable in AST)
+  group_for_reduces: int        # count that are grouped (findable in AST)
+  upcasted: int                 # count that are upcasted (this is remapping RANGE to EXPAND)
+  @property
+  def shape_len(self): return len(self.full_shape)
+
 def check(cond:bool, msg:str=""):
   if not cond: raise KernelOptError(msg)
 
@@ -630,7 +640,7 @@ class Kernel:
     suffix = f"{'n'+str(Kernel.kernel_cnt[function_name]-1)}" if Kernel.kernel_cnt[function_name] > 1 else ""
     return name+colored(suffix, 'BLACK')
 
-  def get_optimized_ast(self) -> Tuple[LazyOp, ...]:
+  def get_optimized_ast(self) -> Tuple[Tuple[LazyOp, ...], KernelInfo]:
     # set the shapetrackers to the optimized ones, fixup reduceop
     # transformed to the final LazyOp
     @functools.lru_cache(None)
@@ -695,4 +705,5 @@ class Kernel:
       else: arg = op.arg
       return LazyOp(op.op, tuple(fixup_ast(x) for x in op.src), arg)
 
-    return tuple(fixup_ast(x) for x in self.ast)
+    return tuple(fixup_ast(x) for x in self.ast), \
+      KernelInfo(self.full_shape, self.global_dims, self.first_reduce, self.group_for_reduces, self.upcasted)
diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py
index 627590bfb0..96c679492c 100644
--- a/tinygrad/codegen/lowerer.py
+++ b/tinygrad/codegen/lowerer.py
@@ -61,11 +61,13 @@ class Lowerer(Kernel):
       if x.arg.idx == -1:
         buf = UOp(UOps.DEFINE_LOCAL, PtrDType(x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype), (), ("temp", x.arg.st.size))
       else:
+        # NOTE: outbufs is quickly findable in AST
         buf = UOp(UOps.DEFINE_GLOBAL, x.arg.dtype if isinstance(x.arg.dtype, ImageDType) else PtrDType(x.arg.dtype), (),
                   (x.arg.idx, any(x.arg.idx == y.idx for y in self.outbufs)))
       if x.op is BufferOps.LOAD:
         barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[0]),)),) if len(x.src) else ()
         return UOp(UOps.LOAD, x.arg.dtype.scalar(), (buf, idx) + ((valid, UOp.const(x.arg.dtype.scalar(), 0)) if has_valid else ()) + barrier)
+      # TODO: what is this?
       if self.group_for_reduces > 0 and x.arg.idx != -1: valid, has_valid = valid * self.idxs[self.first_reduce].eq(0), True
       return UOp(UOps.STORE, None, (buf, idx, self.to_uop(x.src[0])) + ((valid,) if has_valid else ()))
 
@@ -73,7 +75,6 @@ class Lowerer(Kernel):
     if x.op is UnaryOps.CAST: return UOp(UOps.CAST, x.arg.scalar(), in_uops)
     if x.op is UnaryOps.BITCAST: return UOp(UOps.BITCAST, x.arg.scalar(), in_uops)
     if x.op in ReduceOps:
-      # NOTE: always using ridxs is fine here
       dtype = x.dtype.base if isinstance(x.dtype, ImageDType) else x.dtype
       if x.op is ReduceOps.WMMA:
         wmma_sz, upcast_axis = x.arg[4], x.arg[6]
@@ -82,20 +83,20 @@ class Lowerer(Kernel):
                     UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=(upcast_axis[1],)),
                     UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
         return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=upcast_axis[2])
+      # NOTE: always using ridxs is fine here
       return UOp(UOps.REDUCE, dtype, (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg), x.op)
     return UOp.alu(x.op, *in_uops)
 
   def linearize(self) -> Lowerer:
-    modified_ast = self.get_optimized_ast()
+    modified_ast, ki = self.get_optimized_ast()
     if DEBUG >= 4:
       from tinygrad.engine.graph import print_tree
       for mast in modified_ast: print_tree(mast)
 
-    self.idxs = []
     if self.opts.has_local:
       # define indexes
-      global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, self.full_shape[:self.global_dims], 3 if self.opts.has_local else 0)
-      local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+self.group_for_reduces], 3 if self.opts.has_local else 0)  # noqa: E501
+      global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, ki.full_shape[:ki.global_dims], 3)
+      local_idxs, loop_local_idxs = get_grouped_dims("lidx", ki.global_dims, ki.full_shape[ki.global_dims:ki.first_reduce+ki.group_for_reduces], 3)
       self.idxs = global_idxs + local_idxs
 
       # define sizes
@@ -104,26 +105,24 @@ class Lowerer(Kernel):
       self.global_size += [1]*(3-len(self.global_size))
       self.local_size += [1]*(3-len(self.local_size))
     else:
-      # all loops
-      self.idxs = []
-      for i,g in enumerate(self.full_shape[:self.first_reduce]):
-        self.idxs.append(UOp(UOps.RANGE, dtypes.int32, (UOp.const(dtypes.int32, 0), variable_to_uop(g)), (i, False)))
+      # all loops are RANGES
+      self.idxs = [UOp(UOps.RANGE, dtypes.int32, (UOp.const(dtypes.int32, 0), variable_to_uop(g)), (i, False))
+                   for i,g in enumerate(ki.full_shape[:ki.first_reduce])]
       self.global_size, self.local_size = None, None
 
     # reduce loops
-    for i,g in enumerate(self.full_shape[self.first_reduce+self.group_for_reduces:], start=self.first_reduce+self.group_for_reduces):
-      unrolled, is_reduce = i >= (self.shape_len-self.upcasted), self.full_shape[i] != self.output_shape[i]
-      if unrolled:
-        assert isinstance(g, int), "needs to be int to unroll"
-        uop = UOp(UOps.EXPAND, dtypes.int32, tuple(UOp.const(dtypes.int32, j) for j in range(0, g)), i)
-      else:
-        uop = UOp(UOps.RANGE, dtypes.int32, (UOp.const(dtypes.int32, 0), variable_to_uop(g)), (i, is_reduce))
-      self.idxs.append(uop)
+    self.idxs += [UOp(UOps.RANGE, dtypes.int32, (UOp.const(dtypes.int32, 0), variable_to_uop(g)), (i, True))
+                  for i,g in enumerate(ki.full_shape[ki.first_reduce+ki.group_for_reduces:ki.shape_len-ki.upcasted], start=ki.first_reduce+ki.group_for_reduces)]
 
-    # late indexes
+    # upcast loops
+    for i,g in enumerate(ki.full_shape[ki.shape_len-ki.upcasted:], start=ki.shape_len-ki.upcasted):
+      assert isinstance(g, int), "needs to be int to upcast/unroll"
+      self.idxs.append(UOp(UOps.EXPAND, dtypes.int32, tuple(UOp.const(dtypes.int32, j) for j in range(0, g)), i))
+
+    # late indexes (group for reduce)
    self.ridxs = self.idxs[:]
-    for a in range(self.first_reduce, self.first_reduce+self.group_for_reduces):
-      self.ridxs[a] = UOp(UOps.RANGE, dtypes.int32, (UOp.const(dtypes.int32, 0), variable_to_uop(self.full_shape[a])), (1000+a, True))
+    for a in range(ki.first_reduce, ki.first_reduce+ki.group_for_reduces):
+      self.ridxs[a] = UOp(UOps.RANGE, dtypes.int32, (UOp.const(dtypes.int32, 0), variable_to_uop(ki.full_shape[a])), (1000+a, True))
 
     self.uop_cache: Dict[LazyOp, UOp] = {}
     self.uops:UOpGraph = UOpGraph([self.to_uop(x) for x in modified_ast], self.opts)
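
For intuition, here is a minimal standalone sketch (not part of the patch) of how the KernelInfo counts partition full_shape into the four index classes the lowerer emits. The shape and counts are made up, and plain int stands in for sint:

    from dataclasses import dataclass
    from typing import Tuple

    @dataclass(frozen=True)
    class KernelInfo:  # same fields as the dataclass added above, int instead of sint
      full_shape: Tuple[int, ...]
      global_dims: int
      first_reduce: int
      group_for_reduces: int
      upcasted: int
      @property
      def shape_len(self): return len(self.full_shape)

    # hypothetical counts for a 5-axis kernel
    ki = KernelInfo(full_shape=(4, 8, 16, 3, 4), global_dims=2, first_reduce=3, group_for_reduces=0, upcasted=1)
    print(ki.full_shape[:ki.global_dims])                                                # (4, 8) -> SPECIAL gidx
    print(ki.full_shape[ki.global_dims:ki.first_reduce+ki.group_for_reduces])            # (16,)  -> SPECIAL lidx
    print(ki.full_shape[ki.first_reduce+ki.group_for_reduces:ki.shape_len-ki.upcasted])  # (3,)   -> RANGE (reduce loop)
    print(ki.full_shape[ki.shape_len-ki.upcasted:])                                      # (4,)   -> EXPAND (upcast loop)

The slices are exactly the ones linearize() takes in the diff: with has_local the first two become SPECIAL dims, otherwise they become plain RANGEs, and the last two always become reduce RANGEs and upcast EXPANDs.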