diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 4a42b1ed03..c6f0bcb9d2 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -18,7 +18,7 @@ from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner
 from tinygrad.engine.graph import print_tree
-from tinygrad.helpers import DEBUG, prod, Context, getenv, CI
+from tinygrad.helpers import DEBUG, prod, Context, getenv, CI, flatten, dedup
 from tinygrad.dtype import DType, dtypes
 
 def helper_realized_ast(r:Union[Tensor, List[Tensor]]):
@@ -620,7 +620,8 @@ class TestLinearizer(unittest.TestCase):
   def test_grouped_dims(self):
     def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes):
       # TODO: fix reverse_dims
-      idxs, loop_idxs = get_grouped_dims(prefix, 0, dims, max_sizes)
+      idxs = get_grouped_dims(prefix, 0, dims, max_sizes)
+      loop_idxs = dedup(flatten([[y for y in sorted(list(x.sparents)) if y.op is UOps.SPECIAL] for x in idxs]))
       sizes = [x.arg[2] for x in loop_idxs]
       assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
       assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py
index a4be11e0e2..d55b908f56 100644
--- a/tinygrad/codegen/lowerer.py
+++ b/tinygrad/codegen/lowerer.py
@@ -7,7 +7,7 @@ from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
 from tinygrad.ops import BufferOps, LazyOp, TernaryOps, ReduceOps, UnaryOps, MetaOps, get_lazyop_info, KernelInfo
 from tinygrad.codegen.uops import UOp, flops_mem, UOps
 from tinygrad.codegen.uopgraph import UOpGraph
-from tinygrad.renderer import Program
+from tinygrad.renderer import Program, Renderer
 from tinygrad.helpers import to_function_name, DEBUG, getenv, prod, diskcache_put, ContextVar
 
 # TODO: this needs to be replaced, there shouldn't be variables in the shapetracker, only ints and UOps
@@ -55,10 +55,10 @@ else:
     assert uvalid.dtype == dtypes.bool
     return uidx, uvalid
 
-def get_grouped_dims(prefix, start_dim, dims, max_sizes:Optional[Tuple[int, ...]]) -> Tuple[List[UOp], List[UOp]]:
+def get_grouped_dims(prefix, start_dim, dims, max_sizes:Optional[Tuple[int, ...]]) -> List[UOp]:
   # TODO: this should be per dim max
   maxdim = len(max_sizes) if max_sizes is not None else 0
-  local_idxs = loop_local_idxs = [UOp(UOps.SPECIAL, dtypes.bigint, (),
+  local_idxs = [UOp(UOps.SPECIAL, dtypes.bigint, (),
     (i, f"{prefix}{start_dim+i}", s)) for i,s in enumerate((prod(dims[:-(maxdim-1)]),) + dims[-(maxdim-1):] if len(dims) > maxdim else dims)]
   if maxdim != 0 and len(dims) > maxdim:
     dd = local_idxs[0]
@@ -67,9 +67,50 @@
       nli.append(dd % s)
       dd //= s
     local_idxs = nli + local_idxs[-(maxdim-1):]
-  return local_idxs, loop_local_idxs
+  return local_idxs
+
+class IndependentLowerer:
+  def lower(self, ast:LazyOp, opts:Renderer) -> UOp:
+    self.output_count = len(ast.src)
+
+    ki = ast.arg if isinstance(ast.arg, KernelInfo) else KernelInfo()
+    # NOTE: assumes the shape is <global dims> <local dims> <group_for_reduces> <reduces> <upcasts/unrolls>
+    full_shape = ast.full_shape
+    first_reduce = [x!=y for x,y in zip(ast.src[0].arg.st.shape[:len(full_shape)-ki.upcasted]+(0,),
+                                        full_shape[:len(full_shape)-ki.upcasted]+(1,))].index(True)
+    local_loads = [x for x in ast.lazyops if x.op is BufferOps.LOAD and x.arg.idx == -1]
+    # NOTE: this is taking the first one...there may be subtlelies here with multireduces
+    group_for_reduces = sum([x!=y for x,y in zip(
+      local_loads[0].arg.st.shape[first_reduce:len(full_shape)-ki.upcasted],
+      ast.src[0].arg.st.shape[first_reduce:len(full_shape)-ki.upcasted])]) if len(local_loads) else 0
+    global_dims = first_reduce-ki.local_dims
+
+    if opts.has_local:
+      # define indexes for GPU-like execution
+      self.idxs = get_grouped_dims("gidx", 0, full_shape[:global_dims], opts.global_max) + \
+                  get_grouped_dims("lidx", global_dims, full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
+    else:
+      # all loops are RANGES
+      self.idxs = [UOp(UOps.RANGE, dtypes.bigint, (UOp.const(dtypes.bigint, 0), variable_to_uop(g)), (i, False))
+                   for i,g in enumerate(full_shape[:first_reduce])]
+
+    # reduce loops
+    self.idxs += [UOp(UOps.RANGE, dtypes.bigint, (UOp.const(dtypes.bigint, 0), variable_to_uop(g)), (i, True))
+      for i,g in enumerate(full_shape[first_reduce+group_for_reduces:len(full_shape)-ki.upcasted], start=first_reduce+group_for_reduces)]
+
+    # upcast loops
+    for i,g in enumerate(full_shape[len(full_shape)-ki.upcasted:], start=len(full_shape)-ki.upcasted):
+      assert isinstance(g, int), "needs to be int to upcast/unroll"
+      self.idxs.append(UOp(UOps.EXPAND, dtypes.bigint, tuple(UOp.const(dtypes.bigint, j) for j in range(0, g)), i))
+
+    # late indexes (group for reduce)
+    self.ridxs = self.idxs[:]
+    for a in range(first_reduce, first_reduce+group_for_reduces):
+      self.ridxs[a] = UOp(UOps.RANGE, dtypes.bigint, (UOp.const(dtypes.bigint, 0), variable_to_uop(full_shape[a])), (1000+a, True))
+
+    self.uop_cache: Dict[LazyOp, UOp] = {}
+    return self.to_uop(ast)
 
-class Lowerer(Kernel):
   def to_uop(self, x:LazyOp) -> UOp:
     if uop:=self.uop_cache.get(x, None): return uop
     ret = self._to_uop(x)
@@ -88,7 +129,7 @@ class Lowerer(Kernel):
         buf = UOp(UOps.DEFINE_LOCAL, PtrDType(x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype), (), ("temp", x.arg.st.size))
       else:
         buf = UOp(UOps.DEFINE_GLOBAL, x.arg.dtype if isinstance(x.arg.dtype, ImageDType) else PtrDType(x.arg.dtype), (),
-                  (x.arg.idx, any(x.arg.idx == y.arg.idx for y in self.ast.src)))
+                  (x.arg.idx, x.arg.idx < self.output_count))
      if x.op is BufferOps.LOAD:
        barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[0]),)),) if len(x.src) else ()
        return UOp(UOps.LOAD, x.arg.dtype.scalar(), (buf, idx) + ((valid, UOp.const(x.arg.dtype.scalar(), 0)) if has_valid else ()) + barrier)
@@ -115,62 +156,33 @@ class Lowerer(Kernel):
       # NOTE: always using ridxs is fine here
       return UOp(UOps.REDUCE, dtype, (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg), x.op)
     return UOp.alu(x.op, *in_uops)
 
+def lazyop_to_uop(ast:LazyOp, opts:Renderer) -> UOp: return IndependentLowerer().lower(ast, opts)
+# TODO: move this to Kernel
+class Lowerer(Kernel):
   def linearize(self) -> Lowerer:
     modified_ast = self.get_optimized_ast()
-    ki = modified_ast.arg if isinstance(modified_ast.arg, KernelInfo) else KernelInfo()
-    # NOTE: assumes the shape is <global dims> <local dims> <group_for_reduces> <reduces> <upcasts/unrolls>
-    full_shape = modified_ast.full_shape
-    first_reduce = [x!=y for x,y in zip(modified_ast.src[0].arg.st.shape[:len(full_shape)-ki.upcasted]+(0,),
-                                        full_shape[:len(full_shape)-ki.upcasted]+(1,))].index(True)
-    local_loads = [x for x in modified_ast.lazyops if x.op is BufferOps.LOAD and x.arg.idx == -1]
-    # NOTE: this is taking the first one...there may be subtlelies here with multireduces
-    group_for_reduces = sum([x!=y for x,y in zip(
-      local_loads[0].arg.st.shape[first_reduce:len(full_shape)-ki.upcasted],
-      modified_ast.src[0].arg.st.shape[first_reduce:len(full_shape)-ki.upcasted])]) if len(local_loads) else 0
-    global_dims = first_reduce-ki.local_dims
 
     if DEBUG >= 3:
       print(self.name)
       from tinygrad.engine.graph import print_tree
       print_tree(modified_ast)
 
-    if self.opts.has_local:
-      # define indexes
-      global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, full_shape[:global_dims], self.opts.global_max)
-      local_idxs, loop_local_idxs = get_grouped_dims("lidx", global_dims,
-        full_shape[global_dims:first_reduce+group_for_reduces], self.opts.local_max)
-      self.idxs = global_idxs + local_idxs
+    uop_sink = lazyop_to_uop(modified_ast, self.opts)
 
-      # define sizes
-      self.global_size: Optional[List[int]] = [x.arg[2] for x in loop_global_idxs]
-      self.local_size: Optional[List[int]] = [x.arg[2] for x in loop_local_idxs]
-      self.global_size += [1]*(3-len(self.global_size))
-      self.local_size += [1]*(3-len(self.local_size))
+    # extract global/local sizes
+    if self.opts.has_local:
+      self.global_size: Optional[List[int]] = [1,1,1]
+      self.local_size: Optional[List[int]] = [1,1,1]
+      for u in uop_sink.parents:
+        if u.op is UOps.SPECIAL:
+          if u.arg[1][0] == 'l': self.local_size[u.arg[0]] = u.arg[2]
+          else: self.global_size[u.arg[0]] = u.arg[2]
     else:
-      # all loops are RANGES
-      self.idxs = [UOp(UOps.RANGE, dtypes.bigint, (UOp.const(dtypes.bigint, 0), variable_to_uop(g)), (i, False))
-        for i,g in enumerate(full_shape[:first_reduce])]
       self.global_size, self.local_size = None, None
 
-    # reduce loops
-    self.idxs += [UOp(UOps.RANGE, dtypes.bigint, (UOp.const(dtypes.bigint, 0), variable_to_uop(g)), (i, True))
-      for i,g in enumerate(full_shape[first_reduce+group_for_reduces:len(full_shape)-ki.upcasted], start=first_reduce+group_for_reduces)]
-
-    # upcast loops
-    for i,g in enumerate(full_shape[len(full_shape)-ki.upcasted:], start=len(full_shape)-ki.upcasted):
-      assert isinstance(g, int), "needs to be int to upcast/unroll"
-      self.idxs.append(UOp(UOps.EXPAND, dtypes.bigint, tuple(UOp.const(dtypes.bigint, j) for j in range(0, g)), i))
-
-    # late indexes (group for reduce)
-    self.ridxs = self.idxs[:]
-    for a in range(first_reduce, first_reduce+group_for_reduces):
-      self.ridxs[a] = UOp(UOps.RANGE, dtypes.bigint, (UOp.const(dtypes.bigint, 0), variable_to_uop(full_shape[a])), (1000+a, True))
-
-    self.uop_cache: Dict[LazyOp, UOp] = {}
-    self.uops:UOpGraph = UOpGraph(self.to_uop(modified_ast), self.opts)
-
-    # maybe graph the uops
+    # generate the UOpGraph
+    self.uops:UOpGraph = UOpGraph(uop_sink, self.opts)
     if DEBUG >= 5: self.uops.print()
     if getenv("GRAPHUOPS"): self.uops.graph()
diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py
index 2de6bd8039..c555122ca0 100644
--- a/tinygrad/dtype.py
+++ b/tinygrad/dtype.py
@@ -110,6 +110,7 @@ def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else
 # HACK: staticmethods are not callable in 3.8 so we have to compare the class
 DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default', 'bigint')) or v.__class__ is staticmethod)}
 INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()}
+INVERSE_DTYPES_DICT['bigint'] = 'bigint'
 
 def sum_acc_dtype(dt:DType):
   # default acc dtype for sum
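
Note on the get_grouped_dims API change (a minimal sketch, not part of the patch): the function now returns only the grouped index UOps, so callers that need the underlying SPECIAL loop UOps and their launch sizes recover them from each index's parents, as the updated test above does. The shape and max_sizes values below are made-up example inputs.

    from tinygrad.helpers import dedup, flatten
    from tinygrad.codegen.uops import UOps
    from tinygrad.codegen.lowerer import get_grouped_dims

    # grouped "gidx" indexes for a hypothetical (2, 3, 4) shape capped at two launch dims
    idxs = get_grouped_dims("gidx", 0, (2, 3, 4), (16, 16))
    # walk each index's parents to find the SPECIAL UOps that become launch dimensions
    loop_idxs = dedup(flatten([[p for p in sorted(list(u.sparents)) if p.op is UOps.SPECIAL] for u in idxs]))
    sizes = [u.arg[2] for u in loop_idxs]  # a SPECIAL arg is (axis, name, size)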