support opts in contig, simpler (#12400)

This commit is contained in:
George Hotz
2025-10-01 17:20:04 +08:00
committed by GitHub
parent 6c95b1f39d
commit 60e52fbe36
4 changed files with 20 additions and 9 deletions

View File

@@ -1,10 +1,9 @@
import unittest
from tinygrad import Tensor, Device
from tinygrad.helpers import RANGEIFY, CPU_LLVM
from tinygrad.helpers import CPU_LLVM
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.engine.realize import get_program
@unittest.skipIf(RANGEIFY>0, "arg is partial contig in rangeify")
class TestOpts(unittest.TestCase):
def test_opt_upcast(self):
opts = (Opt(OptOps.UPCAST, 0, 4),)

View File

@@ -5,7 +5,7 @@ from typing import cast, Final
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, can_pad, GroupOp
from tinygrad.device import Buffer
from tinygrad.dtype import AddrSpace, dtypes, ImageDType
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element
from tinygrad.codegen.opt import axis_colors, Opt, OptOps, KernelOptError, check, axis_letters
from tinygrad.codegen.simplify import pm_flatten_range
from tinygrad.renderer import Renderer
@@ -242,6 +242,9 @@ class Scheduler:
if not (axis < len(axis_choices)): continue
axes = list(axis_choices[axis])
# tag the reduceop
self.ast = self.ast.substitute({reduceop: reduceop.replace(tag="TC")})
# do optimizations and save the ranges
try:
for i,a in enumerate(axes):
@@ -271,7 +274,7 @@ class Scheduler:
if use_tensor_cores != 2:
# fix the srcs
reduceop = [x for x in self.ast.toposort() if x.op is Ops.REDUCE][0]
reduceop = get_single_element([x for x in self.ast.toposort() if x.op is Ops.REDUCE and x.tag == "TC"])
tne = [x.replace(tag=1) for x in ne]
ret = reduceop.substitute(dict(zip(ne, tne)))
srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)

View File

@@ -2,12 +2,13 @@ from typing import Any, cast, Iterator
import functools, operator, itertools
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify, KernelInfo
from tinygrad.uop.symbolic import sym, symbolic_simple
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup
from tinygrad.schedule.kernelize import Kernel
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
from tinygrad.codegen.simplify import pm_flatten_range
from tinygrad.codegen.opt import Opt
# *****************
# 0. do some cleanup rewrites, mostly copied from the old stuff
@@ -555,6 +556,7 @@ class LocalAddBufferContext:
vars:dict = field(default_factory=dict)
range:int = 0
parent_tags:list = field(default_factory=list)
opts:tuple|None = None
def debuf(ctx:LocalAddBufferContext, buf:UOp):
ret = UOp(Ops.DEFINE_GLOBAL, buf.dtype.ptr(buf.arg), arg=ctx.dg)
@@ -596,10 +598,16 @@ to_define_global = PatternMatcher([
(UPat(Ops.RANGE, name="r"), renumber_range),
])
def get_contiguous(ctx:LocalAddBufferContext, x:UOp):
if isinstance(x.arg, tuple) and all(isinstance(y, Opt) for y in x.arg): ctx.opts = x.arg
return x.src[0]
rangeify_codegen = PatternMatcher([
(UPat(Ops.CONTIGUOUS, name="x"), get_contiguous),
# no NOOP in the kernel graph
# TODO: this can be moved into codegen?
(UPat((Ops.NOOP, Ops.CONTIGUOUS), name="x"), lambda x: x.src[0]),
(UPat(Ops.NOOP, name="x"), lambda x: x.src[0]),
# strip the arg from store
(UPat(Ops.STORE, name="x"), lambda x: x.replace(arg=None) if x.arg is not None else None),
@@ -640,7 +648,8 @@ def split_store(ctx:list[UOp], x:UOp):
metadatas = [ctx[y].metadata for y in lctx.parent_tags]
# NOTE: the hack for COPY is here
ret = ret.sink() if ret.src[1].op not in {Ops.COPY, Ops.BUFFER_VIEW} else ret.src[1]
ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts) if lctx.opts is not None else None) \
if ret.src[1].op not in {Ops.COPY, Ops.BUFFER_VIEW} else ret.src[1]
kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1])
kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
return x.as_buf().assign(kernel)

View File

@@ -79,10 +79,10 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
arg = f"{x.arg:g}" if x.op is Ops.CONST and dtypes.is_float(x.dtype) else f"{x.arg}"
label += f"\n{x.op.name}{idx} {arg}" + (f" {x.src[0].op}" if len(x.src) else "")
try:
if len(rngs:=u.ranges):
label += f"\n({','.join([colored(str(x.arg[0]), axis_colors[x.arg[-1]]) for x in sorted(rngs, key=lambda x: x.arg[0:-1])])})"
if u.op not in {Ops.VIEW, Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None:
label += f"\n{shape_to_str(u.shape)}"
elif len(rngs:=u.ranges):
label += f"\n({','.join([colored(str(x.arg[0]), axis_colors[x.arg[-1]]) for x in sorted(rngs, key=lambda x: x.arg[0:-1])])})"
if u.op is Ops.INDEX:
label += f"\n{u.render()}"
except Exception: