mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-11 07:58:08 -05:00

support opts in contig, simpler (#12400)
@@ -1,10 +1,9 @@
 import unittest
 from tinygrad import Tensor, Device
-from tinygrad.helpers import RANGEIFY, CPU_LLVM
+from tinygrad.helpers import CPU_LLVM
 from tinygrad.codegen.opt import Opt, OptOps
 from tinygrad.engine.realize import get_program
 
-@unittest.skipIf(RANGEIFY>0, "arg is partial contig in rangeify")
 class TestOpts(unittest.TestCase):
   def test_opt_upcast(self):
     opts = (Opt(OptOps.UPCAST, 0, 4),)
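Note: dropping the skipIf means this test now also runs with RANGEIFY enabled; the rest of the commit is what lets a tuple of Opt attached to a CONTIGUOUS survive the rangeify path into the kernel. As a reminder of what the tuple in the test encodes (a hedged sketch using only names visible above):

  from tinygrad.codegen.opt import Opt, OptOps

  # one Opt = (which optimization, which axis, how much): upcast axis 0 by 4
  opts = (Opt(OptOps.UPCAST, 0, 4),)
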
@@ -5,7 +5,7 @@ from typing import cast, Final
 from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, can_pad, GroupOp
 from tinygrad.device import Buffer
 from tinygrad.dtype import AddrSpace, dtypes, ImageDType
-from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts
+from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element
 from tinygrad.codegen.opt import axis_colors, Opt, OptOps, KernelOptError, check, axis_letters
 from tinygrad.codegen.simplify import pm_flatten_range
 from tinygrad.renderer import Renderer
@@ -242,6 +242,9 @@ class Scheduler:
       if not (axis < len(axis_choices)): continue
       axes = list(axis_choices[axis])
 
+      # tag the reduceop
+      self.ast = self.ast.substitute({reduceop: reduceop.replace(tag="TC")})
+
       # do optimizations and save the ranges
       try:
         for i,a in enumerate(axes):
@@ -271,7 +274,7 @@ class Scheduler:
 
     if use_tensor_cores != 2:
       # fix the srcs
-      reduceop = [x for x in self.ast.toposort() if x.op is Ops.REDUCE][0]
+      reduceop = get_single_element([x for x in self.ast.toposort() if x.op is Ops.REDUCE and x.tag == "TC"])
       tne = [x.replace(tag=1) for x in ne]
       ret = reduceop.substitute(dict(zip(ne, tne)))
       srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
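The two Scheduler hunks are a tag-and-retrieve pair: the reduceop is tagged "TC" before the optimizations rewrite the AST (the -242 hunk), then recovered by that tag rather than by position (the -271 hunk). get_single_element, newly imported above, also turns a quiet [0] on a wrong-sized list into a hard failure. A minimal sketch of its behavior, assuming the usual tinygrad.helpers semantics:

  def get_single_element(x):
    # hedged sketch of the helper: exactly one element, or fail loudly
    assert len(x) == 1, f"expected exactly 1 element, got {len(x)}"
    return x[0]

  # before: [... if x.op is Ops.REDUCE][0] silently took whichever REDUCE
  # sorted first; after: the tag filter plus this helper guarantees the op
  # being fixed up is exactly the one tagged in the -242 hunk.
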
@@ -2,12 +2,13 @@ from typing import Any, cast, Iterator
 import functools, operator, itertools
 from dataclasses import dataclass, field
 from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
-from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify
+from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify, KernelInfo
 from tinygrad.uop.symbolic import sym, symbolic_simple
 from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup
 from tinygrad.schedule.kernelize import Kernel
 from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
 from tinygrad.codegen.simplify import pm_flatten_range
+from tinygrad.codegen.opt import Opt
 
 # *****************
 # 0. do some cleanup rewrites, mostly copied from the old stuff
@@ -555,6 +556,7 @@ class LocalAddBufferContext:
   vars:dict = field(default_factory=dict)
   range:int = 0
   parent_tags:list = field(default_factory=list)
+  opts:tuple|None = None
 
 def debuf(ctx:LocalAddBufferContext, buf:UOp):
   ret = UOp(Ops.DEFINE_GLOBAL, buf.dtype.ptr(buf.arg), arg=ctx.dg)
@@ -596,10 +598,16 @@ to_define_global = PatternMatcher([
   (UPat(Ops.RANGE, name="r"), renumber_range),
 ])
 
+def get_contiguous(ctx:LocalAddBufferContext, x:UOp):
+  if isinstance(x.arg, tuple) and all(isinstance(y, Opt) for y in x.arg): ctx.opts = x.arg
+  return x.src[0]
+
 rangeify_codegen = PatternMatcher([
+  (UPat(Ops.CONTIGUOUS, name="x"), get_contiguous),
+
   # no NOOP in the kernel graph
   # TODO: this can be moved into codegen?
-  (UPat((Ops.NOOP, Ops.CONTIGUOUS), name="x"), lambda x: x.src[0]),
+  (UPat(Ops.NOOP, name="x"), lambda x: x.src[0]),
 
   # strip the arg from store
   (UPat(Ops.STORE, name="x"), lambda x: x.replace(arg=None) if x.arg is not None else None),
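get_contiguous is where the new plumbing starts: when a CONTIGUOUS node's arg is a tuple of Opt, the opts are stashed on LocalAddBufferContext (the opts field added in the -555 hunk) and the node is stripped to its source as before; NOOP passthrough stays as its own rule, since the old combined (NOOP, CONTIGUOUS) rule would have dropped the arg. A standalone mock of the capture pattern (MockOpt and MockCtx are hypothetical stand-ins, not tinygrad classes):

  from dataclasses import dataclass

  @dataclass
  class MockOpt:        # stand-in for tinygrad's Opt
    op: str
    axis: int
    amt: int

  @dataclass
  class MockCtx:        # stand-in for LocalAddBufferContext
    opts: tuple | None = None

  def capture_opts(ctx: MockCtx, arg, src):
    # same shape as get_contiguous: stash the opts if arg is a tuple of Opt...
    if isinstance(arg, tuple) and all(isinstance(y, MockOpt) for y in arg): ctx.opts = arg
    return src  # ...and in every case replace the node with its source

  ctx = MockCtx()
  assert capture_opts(ctx, (MockOpt("UPCAST", 0, 4),), "src") == "src"
  assert ctx.opts == (MockOpt("UPCAST", 0, 4),)
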
@@ -640,7 +648,8 @@ def split_store(ctx:list[UOp], x:UOp):
   metadatas = [ctx[y].metadata for y in lctx.parent_tags]
 
   # NOTE: the hack for COPY is here
-  ret = ret.sink() if ret.src[1].op not in {Ops.COPY, Ops.BUFFER_VIEW} else ret.src[1]
+  ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts) if lctx.opts is not None else None) \
+    if ret.src[1].op not in {Ops.COPY, Ops.BUFFER_VIEW} else ret.src[1]
   kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1])
   kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
   return x.as_buf().assign(kernel)
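split_store then hands the captured opts to codegen by attaching them to the kernel's SINK as a KernelInfo (COPY/BUFFER_VIEW keep their existing bypass). Based only on names this diff confirms (KernelInfo importable from tinygrad.uop.ops, its opts_to_apply field) and assuming KernelInfo's other fields have defaults, the sink arg it builds looks like:

  from tinygrad.uop.ops import KernelInfo
  from tinygrad.codegen.opt import Opt, OptOps

  # the sink arg built when opts were captured; None when they were not
  ki = KernelInfo(opts_to_apply=(Opt(OptOps.UPCAST, 0, 4),))
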
@@ -79,10 +79,10 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
       arg = f"{x.arg:g}" if x.op is Ops.CONST and dtypes.is_float(x.dtype) else f"{x.arg}"
       label += f"\n{x.op.name}{idx} {arg}" + (f" {x.src[0].op}" if len(x.src) else "")
     try:
-      if len(rngs:=u.ranges):
-        label += f"\n({','.join([colored(str(x.arg[0]), axis_colors[x.arg[-1]]) for x in sorted(rngs, key=lambda x: x.arg[0:-1])])})"
       if u.op not in {Ops.VIEW, Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None:
         label += f"\n{shape_to_str(u.shape)}"
+      elif len(rngs:=u.ranges):
+        label += f"\n({','.join([colored(str(x.arg[0]), axis_colors[x.arg[-1]]) for x in sorted(rngs, key=lambda x: x.arg[0:-1])])})"
       if u.op is Ops.INDEX:
         label += f"\n{u.render()}"
     except Exception:
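The viz hunk reorders uop_to_json's label logic: a node with a ShapeTracker now shows its shape, and the range list becomes the elif fallback for shape-less nodes, where previously the ranges printed first and could appear alongside the shape. The skeleton of the new priority, with hypothetical helper names:

  # hedged skeleton of the label priority, not the real uop_to_json
  if has_shape(u):      # shape wins when u.st is available
    label += shape_str(u)
  elif has_ranges(u):   # ranges only label nodes without a shape
    label += ranges_str(u)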