s/BUFFER_UOPS/BUFOPS (#7501)

qazal
2024-11-03 10:17:33 +02:00
committed by GitHub
parent c8bf09b7d4
commit 37f8578953
3 changed files with 15 additions and 15 deletions

tinygrad/codegen/kernel.py (View File)

@@ -5,8 +5,8 @@ from collections import defaultdict
from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Callable, Sequence
from enum import Enum, auto
-from tinygrad.ops import UNSAFE_PAD_OPS, BUFFER_UOPS, BinaryOps, KernelInfo, UOp, Ops, PatternMatcher, print_uops, type_verify, resolve, \
-graph_rewrite, track_rewrites, Variable, sint
+from tinygrad.ops import UNSAFE_PAD_OPS, BUFOPS, BinaryOps, KernelInfo, UOp, Ops, PatternMatcher, print_uops, type_verify, resolve, Variable, sint, \
+graph_rewrite, track_rewrites
from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore, Program
from tinygrad.dtype import ImageDType
@@ -68,10 +68,10 @@ class Kernel:
self.reduceops = dedup([x for x in ordered_parents(self.ast) if x.op is Ops.REDUCE_AXIS])
self.vars: List[Variable] = self.ast.variables()
-self.bufs: List[UOp] = [x for x in self.ast.parents if x.op in BUFFER_UOPS]
+self.bufs: List[UOp] = [x for x in self.ast.parents if x.op in BUFOPS]
# get earlybufs, before any reduceops
-earlybufs: List[UOp] = [x for reduceop in self.reduceops for x in reduceop.parents if x.op in BUFFER_UOPS]
+earlybufs: List[UOp] = [x for reduceop in self.reduceops for x in reduceop.parents if x.op in BUFOPS]
self.full_buf_index: int = self.bufs.index(earlybufs[0]) if earlybufs else 0
# NOTE: full_shape can be wrong if there's a tree of reduces
@@ -598,7 +598,7 @@ class Kernel:
@functools.cached_property
def name(self) -> str:
# kernel name (before late upcast)
-kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op in BUFFER_UOPS for x in self.ast.parents) else "E")
+kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op in BUFOPS for x in self.ast.parents) else "E")
suffix = colored('_', 'BLACK').join([colored(x.render() if isinstance(x, UOp) else str(x), c) for x,c in zip(self.full_shape, self.colors())])
name = kernel_type + (f"{len(self.ast.src)}" if len(self.ast.src) > 1 else "") + "_" + suffix
@@ -613,7 +613,7 @@ class Kernel:
@functools.lru_cache(None)
def fixup_ast(op:UOp, apply_to_st=None) -> UOp:
arg = op.arg
-if op.op in BUFFER_UOPS:
+if op.op in BUFOPS:
# for locals, we use the ShapeTracker that's in the srcs
st = op.st_arg if op.src[0].op is Ops.DEFINE_LOCAL else self.sts[self.bufs.index(op)]
st_uop = (st if apply_to_st is None else apply_to_st(st)).to_uop()
@@ -728,7 +728,7 @@ class Kernel:
# group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
# TODO: these max and min don't work on symbolic, and results are very wrong.
mem_bytes = sum(max(x.src[0].dtype.itemsize * x.st_arg.real_size() for x in group)
-for _, group in itertools.groupby([x for x in self.ast.parents if x.op in BUFFER_UOPS and x.src[0].op is Ops.DEFINE_GLOBAL],
+for _, group in itertools.groupby([x for x in self.ast.parents if x.op in BUFOPS and x.src[0].op is Ops.DEFINE_GLOBAL],
key=lambda x: (x.op, x.src[0].arg)))
return Program(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
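
The comment in the hunk above describes how the memory estimate is computed: buffer ops on DEFINE_GLOBAL buffers are grouped by (op, buffer arg) and the largest access of each buffer is summed. A minimal, self-contained sketch of that grouping idea, using plain tuples and hypothetical sizes rather than tinygrad's UOp/ShapeTracker objects:

# toy sketch of the mem_bytes grouping above: group accesses by (op, buffer index)
# and sum the largest access per group; data and names here are made up for illustration
import itertools

# (op, buffer_index, access_bytes) -- hypothetical accesses of two buffers
accesses = [("LOAD", 0, 4096), ("LOAD", 0, 2048), ("LOAD", 1, 1024), ("STORE", 1, 1024)]

key = lambda a: (a[0], a[1])
# itertools.groupby only merges consecutive items, so sort by the key first
mem_bytes = sum(max(a[2] for a in group) for _, group in itertools.groupby(sorted(accesses, key=key), key=key))
print(mem_bytes)  # 4096 + 1024 + 1024 = 6144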

tinygrad/engine/schedule.py (View File)

@@ -2,7 +2,7 @@ import sys, atexit, functools, itertools
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast
-from tinygrad.ops import BUFFER_UOPS, MetaOps, ReduceOps, UnaryOps, UOp, Ops, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint
+from tinygrad.ops import BUFOPS, MetaOps, ReduceOps, UnaryOps, UOp, Ops, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint
from tinygrad.helpers import DEBUG, Metadata, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.shape.shapetracker import ShapeTracker
@@ -143,7 +143,7 @@ merge_views = PatternMatcher([(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="s0"),),
# push VIEW to loads
view_left = merge_views+PatternMatcher([
# view before ALU
-(UPat(Ops.VIEW, src=(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, *BUFFER_UOPS), name="e"),), name="v"),
+(UPat(Ops.VIEW, src=(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, *BUFOPS), name="e"),), name="v"),
lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
])
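
The "view before ALU" rule above pushes a VIEW that sits on top of an elementwise, cast, assign, contiguous, or buffer op down into that op's sources. A rough sketch of the same idea on a toy node type (not tinygrad's UOp or PatternMatcher API), handling just the ALU case:

# VIEW(ALU(a, b), st) -> ALU(VIEW(a, st), VIEW(b, st)): reapply the view to each source
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass(frozen=True)
class Node:                      # toy stand-in for a UOp, only for illustration
  op: str
  src: Tuple["Node", ...] = ()
  st: Optional[str] = None       # stand-in for a ShapeTracker

def push_view_left(n: Node) -> Node:
  # if a VIEW wraps an ALU, move the view onto each of the ALU's sources
  if n.op == "VIEW" and n.src and n.src[0].op == "ALU":
    return Node("ALU", tuple(Node("VIEW", (s,), n.st) for s in n.src[0].src))
  return n

expr = Node("VIEW", (Node("ALU", (Node("LOAD"), Node("LOAD"))),), "reshape(4,4)")
print(push_view_left(expr))
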
@@ -261,7 +261,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
metadata: List[Set[Metadata]] = []
for stores in store_groups:
sink = UOp.sink(*(ctx.realizes[u] for u in stores))
-metadata.append({mx for x in sink.sparents if x.op in BUFFER_UOPS and len(x.src) > 2 and (mx:=ctx.ubuf_metadata.get(x.src[0]))})
+metadata.append({mx for x in sink.sparents if x.op in BUFOPS and len(x.src) > 2 and (mx:=ctx.ubuf_metadata.get(x.src[0]))})
small_graphs.append(full_ast_rewrite(sink, ctx.var_vals, assigned))
# do BFS

tinygrad/ops.py (View File)

@@ -183,7 +183,7 @@ class Ops(FastEnum):
VCONST = auto()
CONST = auto()
-BUFFER_UOPS = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID}
+BUFOPS = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID}
COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR}
END_FOR_UOP = {Ops.IF:(Ops.STORE, Ops.ENDIF), Ops.RANGE:(Ops.ASSIGN, Ops.ENDRANGE)}
@@ -260,7 +260,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@functools.cached_property
def st(self) -> Optional[ShapeTracker]:
if not self.has_st: return None
-if self.op in BUFFER_UOPS: return self.st_arg
+if self.op in BUFOPS: return self.st_arg
if self.op is Ops.VIEW: return self.arg
src_sts = [x.st for x in self.src if x.st is not None]
assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}"
@@ -293,7 +293,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@property
def st_arg(self) -> ShapeTracker:
-assert self.op in BUFFER_UOPS, f"st_arg called on {self.op}"
+assert self.op in BUFOPS, f"st_arg called on {self.op}"
ret = self.src[0 if self.op is Ops.VALID else 1]
assert ret.op is Ops.VIEW, f"st_arg trying to return {ret}"
return ret.arg
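
st_arg encodes the convention visible in these hunks for where a buffer op keeps its ShapeTracker: VALID carries the VIEW as src[0], while LOAD/PRELOAD/STORE keep the buffer in src[0] and the VIEW in src[1]. A minimal sketch of that lookup on a toy class (hypothetical names, not tinygrad's real UOp):

# toy version of the src-index convention above; ToyUOp is not tinygrad's UOp
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class ToyUOp:
  op: str
  src: Tuple["ToyUOp", ...] = ()
  arg: object = None

def st_arg(u: ToyUOp):
  assert u.op in {"LOAD", "PRELOAD", "STORE", "VALID"}, f"st_arg called on {u.op}"
  view = u.src[0] if u.op == "VALID" else u.src[1]   # VALID: src[0]; others: src[1]
  assert view.op == "VIEW", f"st_arg trying to return {view}"
  return view.arg                                    # the ShapeTracker lives in the VIEW's arg

buf = ToyUOp("DEFINE_GLOBAL")
view = ToyUOp("VIEW", arg="ShapeTracker((4, 4))")
print(st_arg(ToyUOp("LOAD", (buf, view))))   # ShapeTracker((4, 4))
print(st_arg(ToyUOp("VALID", (view,))))      # ShapeTracker((4, 4))
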
@@ -368,7 +368,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
all_vars = set([x for x in self.sparents if x.op is Ops.DEFINE_VAR])
return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
def variables(self) -> List[Variable]:
-st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
+st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFOPS]
return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
# *** uop symbolic stuff ***
@@ -1139,4 +1139,4 @@ renderer = PatternMatcher([
sint = Union[int, UOp]
Variable = UOp
-ConstLike = Union[ConstType, Variable, Tuple[ConstType, ...]]
+ConstLike = Union[ConstType, Variable, Tuple[ConstType, ...]]