diff --git a/test/test_opts.py b/test/test_opts.py
index 4a6310ef32..7bbdcfb61b 100644
--- a/test/test_opts.py
+++ b/test/test_opts.py
@@ -1,10 +1,9 @@
 import unittest
 from tinygrad import Tensor, Device
-from tinygrad.helpers import RANGEIFY, CPU_LLVM
+from tinygrad.helpers import CPU_LLVM
 from tinygrad.codegen.opt import Opt, OptOps
 from tinygrad.engine.realize import get_program
 
-@unittest.skipIf(RANGEIFY>0, "arg is partial contig in rangeify")
 class TestOpts(unittest.TestCase):
   def test_opt_upcast(self):
     opts = (Opt(OptOps.UPCAST, 0, 4),)
diff --git a/tinygrad/codegen/opt/postrange.py b/tinygrad/codegen/opt/postrange.py
index 1fe4897242..10a415ec4f 100644
--- a/tinygrad/codegen/opt/postrange.py
+++ b/tinygrad/codegen/opt/postrange.py
@@ -5,7 +5,7 @@ from typing import cast, Final
 from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, can_pad, GroupOp
 from tinygrad.device import Buffer
 from tinygrad.dtype import AddrSpace, dtypes, ImageDType
-from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts
+from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element
 from tinygrad.codegen.opt import axis_colors, Opt, OptOps, KernelOptError, check, axis_letters
 from tinygrad.codegen.simplify import pm_flatten_range
 from tinygrad.renderer import Renderer
@@ -242,6 +242,9 @@ class Scheduler:
       if not (axis < len(axis_choices)): continue
       axes = list(axis_choices[axis])
 
+      # tag the reduceop
+      self.ast = self.ast.substitute({reduceop: reduceop.replace(tag="TC")})
+
       # do optimizations and save the ranges
       try:
         for i,a in enumerate(axes):
@@ -271,7 +274,7 @@ class Scheduler:
 
     if use_tensor_cores != 2:
       # fix the srcs
-      reduceop = [x for x in self.ast.toposort() if x.op is Ops.REDUCE][0]
+      reduceop = get_single_element([x for x in self.ast.toposort() if x.op is Ops.REDUCE and x.tag == "TC"])
       tne = [x.replace(tag=1) for x in ne]
       ret = reduceop.substitute(dict(zip(ne, tne)))
       srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index 01eac77fcb..06a930e64c 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -2,12 +2,13 @@ from typing import Any, cast, Iterator
 import functools, operator, itertools
 from dataclasses import dataclass, field
 from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
-from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify
+from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify, KernelInfo
 from tinygrad.uop.symbolic import sym, symbolic_simple
 from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup
 from tinygrad.schedule.kernelize import Kernel
 from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
 from tinygrad.codegen.simplify import pm_flatten_range
+from tinygrad.codegen.opt import Opt
 
 # *****************
 # 0. do some cleanup rewrites, mostly copied from the old stuff
@@ -555,6 +556,7 @@ class LocalAddBufferContext:
   vars:dict = field(default_factory=dict)
   range:int = 0
   parent_tags:list = field(default_factory=list)
+  opts:tuple|None = None
 
 def debuf(ctx:LocalAddBufferContext, buf:UOp):
   ret = UOp(Ops.DEFINE_GLOBAL, buf.dtype.ptr(buf.arg), arg=ctx.dg)
@@ -596,10 +598,16 @@ to_define_global = PatternMatcher([
   (UPat(Ops.RANGE, name="r"), renumber_range),
 ])
 
+def get_contiguous(ctx:LocalAddBufferContext, x:UOp):
+  if isinstance(x.arg, tuple) and all(isinstance(y, Opt) for y in x.arg): ctx.opts = x.arg
+  return x.src[0]
+
 rangeify_codegen = PatternMatcher([
+  (UPat(Ops.CONTIGUOUS, name="x"), get_contiguous),
+
   # no NOOP in the kernel graph
   # TODO: this can be moved into codegen?
-  (UPat((Ops.NOOP, Ops.CONTIGUOUS), name="x"), lambda x: x.src[0]),
+  (UPat(Ops.NOOP, name="x"), lambda x: x.src[0]),
 
   # strip the arg from store
   (UPat(Ops.STORE, name="x"), lambda x: x.replace(arg=None) if x.arg is not None else None),
@@ -640,7 +648,8 @@ def split_store(ctx:list[UOp], x:UOp):
   metadatas = [ctx[y].metadata for y in lctx.parent_tags]
 
   # NOTE: the hack for COPY is here
-  ret = ret.sink() if ret.src[1].op not in {Ops.COPY, Ops.BUFFER_VIEW} else ret.src[1]
+  ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts) if lctx.opts is not None else None) \
+    if ret.src[1].op not in {Ops.COPY, Ops.BUFFER_VIEW} else ret.src[1]
   kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1])
   kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
   return x.as_buf().assign(kernel)
diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py
index b3106c0241..7ac23c89f8 100755
--- a/tinygrad/viz/serve.py
+++ b/tinygrad/viz/serve.py
@@ -79,10 +79,10 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
       arg = f"{x.arg:g}" if x.op is Ops.CONST and dtypes.is_float(x.dtype) else f"{x.arg}"
      label += f"\n{x.op.name}{idx} {arg}" + (f" {x.src[0].op}" if len(x.src) else "")
     try:
+      if len(rngs:=u.ranges):
+        label += f"\n({','.join([colored(str(x.arg[0]), axis_colors[x.arg[-1]]) for x in sorted(rngs, key=lambda x: x.arg[0:-1])])})"
       if u.op not in {Ops.VIEW, Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None:
         label += f"\n{shape_to_str(u.shape)}"
-      elif len(rngs:=u.ranges):
-        label += f"\n({','.join([colored(str(x.arg[0]), axis_colors[x.arg[-1]]) for x in sorted(rngs, key=lambda x: x.arg[0:-1])])})"
       if u.op is Ops.INDEX: label += f"\n{u.render()}"
     except Exception:
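
The core of this diff is a capture-and-forward path for hand-specified kernel opts: a tuple of `Opt`s can now ride on a `CONTIGUOUS` UOp's arg (presumably why the rangeify skip in test_opts.py can be dropped), the new `get_contiguous` rule in `rangeify_codegen` stashes that tuple into `LocalAddBufferContext.opts` while stripping the `CONTIGUOUS` node, and `split_store` forwards it onto the kernel's `SINK` as `KernelInfo(opts_to_apply=...)`. The postrange.py change is the matching consumer-side fix: the reduceop is tagged `"TC"` before the opt rewrites run, so the tensor-core source fixup can later recover exactly that `REDUCE` via `get_single_element` instead of blindly taking `[0]` from the toposort. Below is a minimal, self-contained sketch of the capture-and-forward flow, using toy `Node`/`Ctx` stand-ins rather than tinygrad's real `UOp`/`PatternMatcher`/`LocalAddBufferContext`:

```python
# toy sketch of the opts capture-and-forward in this diff -- stand-in classes,
# not tinygrad's real UOp/PatternMatcher machinery
from dataclasses import dataclass

@dataclass(frozen=True)
class Opt:          # stand-in for tinygrad.codegen.opt.Opt
  op: str
  axis: int
  amt: int

@dataclass(frozen=True)
class Node:         # stand-in for a UOp: op name, sources, optional arg
  op: str
  src: tuple = ()
  arg: object = None

@dataclass
class Ctx:          # mirrors the new LocalAddBufferContext.opts field
  opts: tuple | None = None

def strip_contiguous(ctx: Ctx, n: Node) -> Node:
  # like get_contiguous: if a CONTIGUOUS node carries a tuple of Opts,
  # capture it into the context, then drop the node (return its source)
  if n.op == "CONTIGUOUS":
    if isinstance(n.arg, tuple) and all(isinstance(o, Opt) for o in n.arg): ctx.opts = n.arg
    return strip_contiguous(ctx, n.src[0])
  return Node(n.op, tuple(strip_contiguous(ctx, s) for s in n.src), n.arg)

def make_sink(ctx: Ctx, root: Node) -> Node:
  # like the split_store change: the captured opts become the SINK's arg
  # (tinygrad wraps them as KernelInfo(opts_to_apply=ctx.opts))
  return Node("SINK", (root,), arg=ctx.opts)

if __name__ == "__main__":
  ast = Node("STORE", (Node("CONTIGUOUS", (Node("ADD"),), arg=(Opt("UPCAST", 0, 4),)),))
  ctx = Ctx()
  print(make_sink(ctx, strip_contiguous(ctx, ast)).arg)  # (Opt(op='UPCAST', axis=0, amt=4),)
```

Note that plain `NOOP`s still just pass through to their source; only `CONTIGUOUS` gets the extra arg inspection, which is why the old combined `(Ops.NOOP, Ops.CONTIGUOUS)` pattern is split into two rules.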