s/BUFFER_UOPS/BUFOPS (#7501)
@@ -5,8 +5,8 @@ from collections import defaultdict
 from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Callable, Sequence
 from enum import Enum, auto

-from tinygrad.ops import UNSAFE_PAD_OPS, BUFFER_UOPS, BinaryOps, KernelInfo, UOp, Ops, PatternMatcher, print_uops, type_verify, resolve, \
-  graph_rewrite, track_rewrites, Variable, sint
+from tinygrad.ops import UNSAFE_PAD_OPS, BUFOPS, BinaryOps, KernelInfo, UOp, Ops, PatternMatcher, print_uops, type_verify, resolve, Variable, sint, \
+  graph_rewrite, track_rewrites
 from tinygrad.device import Device
 from tinygrad.renderer import Renderer, TensorCore, Program
 from tinygrad.dtype import ImageDType

@@ -68,10 +68,10 @@ class Kernel:
     self.reduceops = dedup([x for x in ordered_parents(self.ast) if x.op is Ops.REDUCE_AXIS])

     self.vars: List[Variable] = self.ast.variables()
-    self.bufs: List[UOp] = [x for x in self.ast.parents if x.op in BUFFER_UOPS]
+    self.bufs: List[UOp] = [x for x in self.ast.parents if x.op in BUFOPS]

     # get earlybufs, before any reduceops
-    earlybufs: List[UOp] = [x for reduceop in self.reduceops for x in reduceop.parents if x.op in BUFFER_UOPS]
+    earlybufs: List[UOp] = [x for reduceop in self.reduceops for x in reduceop.parents if x.op in BUFOPS]
     self.full_buf_index: int = self.bufs.index(earlybufs[0]) if earlybufs else 0
     # NOTE: full_shape can be wrong if there's a tree of reduces

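The whole commit is a mechanical rename of `BUFFER_UOPS` to `BUFOPS`; every hunk is a membership check or import against the same set. Here, `bufs` collects every buffer access in the AST and `earlybufs` only those feeding a reduce. A minimal sketch of that filtering pattern, with an identity-hashed `Node` standing in for tinygrad's real `UOp`:

```python
from enum import Enum, auto
from typing import Tuple

class Ops(Enum):
    LOAD = auto(); STORE = auto(); ADD = auto(); REDUCE_AXIS = auto()

BUFOPS = {Ops.LOAD, Ops.STORE}  # trimmed stand-in for the real set

class Node:
    def __init__(self, op: Ops, src: Tuple["Node", ...] = ()):
        self.op, self.src = op, src  # default identity hash/eq, like UOp

    @property
    def parents(self):
        # transitive sources in discovery order, roughly UOp.parents
        out = []
        def walk(n):
            for s in n.src:
                if s not in out:
                    out.append(s); walk(s)
        walk(self)
        return out

a, b = Node(Ops.LOAD), Node(Ops.LOAD)
red = Node(Ops.REDUCE_AXIS, (Node(Ops.ADD, (a, b)),))
ast = Node(Ops.STORE, (red,))

bufs = [x for x in ast.parents if x.op in BUFOPS]       # all buffer accesses
earlybufs = [x for x in red.parents if x.op in BUFOPS]  # only those under the reduce
print(len(bufs), len(earlybufs))  # 2 2
```
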
@@ -598,7 +598,7 @@ class Kernel:
   @functools.cached_property
   def name(self) -> str:
     # kernel name (before late upcast)
-    kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op in BUFFER_UOPS for x in self.ast.parents) else "E")
+    kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op in BUFOPS for x in self.ast.parents) else "E")
     suffix = colored('_', 'BLACK').join([colored(x.render() if isinstance(x, UOp) else str(x), c) for x,c in zip(self.full_shape, self.colors())])
     name = kernel_type + (f"{len(self.ast.src)}" if len(self.ast.src) > 1 else "") + "_" + suffix

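The `kernel_type` expression encodes a three-way classification: `r` for kernels containing a reduce, `C` when every AST parent is a buffer op (a pure copy/movement kernel), and `E` for everything else (elementwise). A hedged sketch of just that decision, with plain op lists in place of real UOps:

```python
from enum import Enum, auto

class Ops(Enum):
    LOAD = auto(); STORE = auto(); ADD = auto(); REDUCE_AXIS = auto()

BUFOPS = {Ops.LOAD, Ops.STORE}  # stand-in for the real set

def kernel_type(parent_ops, has_reduce):
    # mirrors: "r" if reduce else ("C" if all buffer ops else "E")
    if has_reduce: return "r"
    return "C" if all(op in BUFOPS for op in parent_ops) else "E"

print(kernel_type([Ops.LOAD, Ops.STORE], False))           # C  (pure copy)
print(kernel_type([Ops.LOAD, Ops.ADD, Ops.STORE], False))  # E  (elementwise)
print(kernel_type([Ops.LOAD, Ops.STORE], True))            # r  (reduce)
```
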
@@ -613,7 +613,7 @@ class Kernel:
     @functools.lru_cache(None)
     def fixup_ast(op:UOp, apply_to_st=None) -> UOp:
       arg = op.arg
-      if op.op in BUFFER_UOPS:
+      if op.op in BUFOPS:
         # for locals, we use the ShapeTracker that's in the srcs
         st = op.st_arg if op.src[0].op is Ops.DEFINE_LOCAL else self.sts[self.bufs.index(op)]
         st_uop = (st if apply_to_st is None else apply_to_st(st)).to_uop()

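Note that `fixup_ast` is a recursive rewriter memoized with `functools.lru_cache(None)`: because UOps hash by identity, a DAG with shared subtrees is rewritten exactly once per node. A toy reconstruction of just that mechanism (`Node` is a stand-in, not tinygrad's UOp):

```python
import functools
from typing import Tuple

class Node:
    def __init__(self, op: str, src: Tuple["Node", ...] = ()):
        self.op, self.src = op, src  # default object hash/eq: identity

calls = 0

@functools.lru_cache(None)
def fixup(n: Node) -> Node:
    # rebuild the node bottom-up; the cache short-circuits shared subtrees
    global calls; calls += 1
    return Node(n.op.upper(), tuple(fixup(s) for s in n.src))

leaf = Node("load")
root = Node("store", (Node("add", (leaf, leaf)),))  # leaf is shared
fixup(root)
print(calls)  # 3: root, add, and the shared leaf only once
```
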
@@ -728,7 +728,7 @@ class Kernel:
     # group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
     # TODO: these max and min don't work on symbolic, and results are very wrong.
     mem_bytes = sum(max(x.src[0].dtype.itemsize * x.st_arg.real_size() for x in group)
-      for _, group in itertools.groupby([x for x in self.ast.parents if x.op in BUFFER_UOPS and x.src[0].op is Ops.DEFINE_GLOBAL],
+      for _, group in itertools.groupby([x for x in self.ast.parents if x.op in BUFOPS and x.src[0].op is Ops.DEFINE_GLOBAL],
        key=lambda x: (x.op, x.src[0].arg)))
     return Program(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
                    global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)

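The memory estimate sums, per `(op, buffer)` pair, the largest access to that buffer in bytes. `itertools.groupby` only merges *consecutive* items, so a self-contained sketch of the same computation over plain tuples sorts by the key first (the real code relies on the ordering of `ast.parents`):

```python
import itertools

# (op, buffer_index, itemsize, real_size) stand-ins for the real UOp fields
accesses = [
    ("LOAD",  0, 4, 1024),
    ("LOAD",  0, 4, 2048),   # same buffer, bigger view: only this one counts
    ("LOAD",  1, 2, 512),
    ("STORE", 2, 4, 1024),
]

key = lambda x: (x[0], x[1])  # group by (op type, buffer arg)
mem_bytes = sum(max(itemsize * size for _, _, itemsize, size in group)
                for _, group in itertools.groupby(sorted(accesses, key=key), key=key))
print(mem_bytes)  # 4*2048 + 2*512 + 4*1024 = 13312
```
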
@@ -2,7 +2,7 @@ import sys, atexit, functools, itertools
 from collections import defaultdict, deque
 from dataclasses import dataclass, field
 from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast
-from tinygrad.ops import BUFFER_UOPS, MetaOps, ReduceOps, UnaryOps, UOp, Ops, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint
+from tinygrad.ops import BUFOPS, MetaOps, ReduceOps, UnaryOps, UOp, Ops, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint
 from tinygrad.helpers import DEBUG, Metadata, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
 from tinygrad.dtype import ImageDType, dtypes
 from tinygrad.shape.shapetracker import ShapeTracker

@@ -143,7 +143,7 @@ merge_views = PatternMatcher([(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="s0"),),
 # push VIEW to loads
 view_left = merge_views+PatternMatcher([
   # view before ALU
-  (UPat(Ops.VIEW, src=(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, *BUFFER_UOPS), name="e"),), name="v"),
+  (UPat(Ops.VIEW, src=(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, *BUFOPS), name="e"),), name="v"),
    lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
 ])

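This rule pushes a VIEW below the listed ops by re-wrapping each source in the view instead. A compact sketch of the general rewrite idea using predicate/transform pairs on a toy tree (PatternMatcher, UPat, and the ShapeTracker details are simplified away; in the real rule a view over a buffer op folds into its ShapeTracker rather than staying as a node):

```python
from dataclasses import dataclass
from typing import Tuple, Callable, List

@dataclass(frozen=True)
class Node:
    op: str
    src: Tuple["Node", ...] = ()

# each rule: (does it match?, how to rewrite it)
Rule = Tuple[Callable[[Node], bool], Callable[[Node], Node]]

PUSHABLE = {"alu", "cast"}  # toy subset of (ALU, CAST, ..., *BUFOPS)

def push_view(n: Node) -> Node:
    # VIEW(e) -> e with every source wrapped in the view
    e = n.src[0]
    return Node(e.op, tuple(Node("view", (s,)) for s in e.src))

rules: List[Rule] = [
    (lambda n: n.op == "view" and n.src and n.src[0].op in PUSHABLE, push_view),
]

def rewrite(n: Node) -> Node:
    n = Node(n.op, tuple(rewrite(s) for s in n.src))  # bottom-up
    for pred, fn in rules:
        if pred(n): return rewrite(fn(n))
    return n

t = Node("view", (Node("alu", (Node("load"), Node("load"))),))
print(rewrite(t))  # alu(view(load), view(load)): the view moved below the ALU
```
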
@@ -261,7 +261,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
   metadata: List[Set[Metadata]] = []
   for stores in store_groups:
     sink = UOp.sink(*(ctx.realizes[u] for u in stores))
-    metadata.append({mx for x in sink.sparents if x.op in BUFFER_UOPS and len(x.src) > 2 and (mx:=ctx.ubuf_metadata.get(x.src[0]))})
+    metadata.append({mx for x in sink.sparents if x.op in BUFOPS and len(x.src) > 2 and (mx:=ctx.ubuf_metadata.get(x.src[0]))})
     small_graphs.append(full_ast_rewrite(sink, ctx.var_vals, assigned))

   # do BFS

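The set comprehension uses the walrus operator as a combined lookup-and-filter: `ctx.ubuf_metadata.get(...)` is evaluated once, bound to `mx`, and falsy results (missing or `None` entries) are dropped. The same idiom in isolation, with hypothetical placeholder data:

```python
# hypothetical buffer -> metadata table; None/missing entries must not reach the set
ubuf_metadata = {"buf0": "relu", "buf1": None, "buf2": "matmul"}
sparents = ["buf0", "buf1", "buf2", "buf3"]

# bind the lookup to mx and filter on it in a single pass
metadata = {mx for x in sparents if (mx := ubuf_metadata.get(x))}
print(metadata)  # {'relu', 'matmul'} (set order may vary)
```
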
@@ -183,7 +183,7 @@ class Ops(FastEnum):
   VCONST = auto()
   CONST = auto()

-BUFFER_UOPS = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID}
+BUFOPS = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID}
 COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR}
 END_FOR_UOP = {Ops.IF:(Ops.STORE, Ops.ENDIF), Ops.RANGE:(Ops.ASSIGN, Ops.ENDRANGE)}

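This is the single definition the commit renames: a plain set of enum members, so `x.op in BUFOPS` is one hash lookup rather than a chain of comparisons. A self-contained sketch of the idiom (using stdlib `Enum` in place of tinygrad's `FastEnum`):

```python
from enum import Enum, auto

class Ops(Enum):
    LOAD = auto(); PRELOAD = auto(); STORE = auto(); VALID = auto(); ADD = auto()

BUFOPS = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID}

# one hash lookup instead of `op is Ops.LOAD or op is Ops.PRELOAD or ...`
print(Ops.LOAD in BUFOPS, Ops.ADD in BUFOPS)  # True False
```
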
@@ -260,7 +260,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   @functools.cached_property
   def st(self) -> Optional[ShapeTracker]:
     if not self.has_st: return None
-    if self.op in BUFFER_UOPS: return self.st_arg
+    if self.op in BUFOPS: return self.st_arg
     if self.op is Ops.VIEW: return self.arg
     src_sts = [x.st for x in self.src if x.st is not None]
     assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}"

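The fallthrough path of `st` is shape inheritance: a node without its own view takes its sources' ShapeTracker, after asserting the sources all agree on shape. A reduced sketch of that last step over plain shape tuples:

```python
from typing import List, Optional, Tuple

def all_same(items: List) -> bool:
    return all(x == items[0] for x in items)

# a node without its own view inherits its sources' (asserted-identical) shape
def infer_shape(src_shapes: List[Optional[Tuple[int, ...]]]) -> Optional[Tuple[int, ...]]:
    known = [s for s in src_shapes if s is not None]
    if not known: return None
    assert all_same(known), f"UOp parents must have the same shape {known}"
    return known[0]

print(infer_shape([(4, 4), (4, 4), None]))  # (4, 4)
```
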
@@ -293,7 +293,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):

   @property
   def st_arg(self) -> ShapeTracker:
-    assert self.op in BUFFER_UOPS, f"st_arg called on {self.op}"
+    assert self.op in BUFOPS, f"st_arg called on {self.op}"
     ret = self.src[0 if self.op is Ops.VALID else 1]
     assert ret.op is Ops.VIEW, f"st_arg trying to return {ret}"
     return ret.arg

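`st_arg` encodes a src-layout convention: a VALID carries its ShapeTracker VIEW as `src[0]`, while LOAD/PRELOAD/STORE keep the buffer at `src[0]` and the VIEW at `src[1]`. A toy version of the accessor with string-tagged nodes:

```python
from typing import Tuple, Any

class Node:
    def __init__(self, op: str, src: Tuple["Node", ...] = (), arg: Any = None):
        self.op, self.src, self.arg = op, src, arg  # VIEW keeps its tracker in .arg

BUFOPS = {"LOAD", "PRELOAD", "STORE", "VALID"}

def st_arg(n: Node):
    assert n.op in BUFOPS, f"st_arg called on {n.op}"
    ret = n.src[0 if n.op == "VALID" else 1]  # VALID: view first; others: buffer first
    assert ret.op == "VIEW", f"st_arg trying to return {ret}"
    return ret.arg

view = Node("VIEW", arg="ShapeTracker(shape=(4,4))")  # placeholder arg
buf = Node("DEFINE_GLOBAL")
print(st_arg(Node("LOAD", (buf, view))))  # the view's arg, from src[1]
print(st_arg(Node("VALID", (view,))))     # same, from src[0]
```
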
@@ -368,7 +368,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     all_vars = set([x for x in self.sparents if x.op is Ops.DEFINE_VAR])
     return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
   def variables(self) -> List[Variable]:
-    st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
+    st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFOPS]
     return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)

 # *** uop symbolic stuff ***

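`variables()` unions the variable sets of every buffer op's ShapeTracker with the node's own (unbound) variables, then sorts for a deterministic order. The unbound `set.union(*sets, extra)` call pattern in isolation (strings stand in for Variables sorted by `.arg`; note the first argument must be a real set when `set.union` is called on the class):

```python
# one variable set per LOAD/STORE ShapeTracker, plus the node's own variables
st_vars = [{"i", "j"}, {"j", "k"}]
own_vars = ["k", "n"]

# union accepts arbitrary iterables after the first set argument
merged = sorted(set.union(*st_vars, own_vars))
print(merged)  # ['i', 'j', 'k', 'n']
```
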
@@ -1139,4 +1139,4 @@ renderer = PatternMatcher([
 sint = Union[int, UOp]
 Variable = UOp

-ConstLike = Union[ConstType, Variable, Tuple[ConstType, ...]]
+ConstLike = Union[ConstType, Variable, Tuple[ConstType, ...]]
