diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 1cfb5a9345..70f8f1e20f 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -5,8 +5,8 @@ from collections import defaultdict
 from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Callable, Sequence
 from enum import Enum, auto
 
-from tinygrad.ops import UNSAFE_PAD_OPS, BUFFER_UOPS, BinaryOps, KernelInfo, UOp, Ops, PatternMatcher, print_uops, type_verify, resolve, \
-  graph_rewrite, track_rewrites, Variable, sint
+from tinygrad.ops import UNSAFE_PAD_OPS, BUFOPS, BinaryOps, KernelInfo, UOp, Ops, PatternMatcher, print_uops, type_verify, resolve, Variable, sint, \
+  graph_rewrite, track_rewrites
 from tinygrad.device import Device
 from tinygrad.renderer import Renderer, TensorCore, Program
 from tinygrad.dtype import ImageDType
@@ -68,10 +68,10 @@ class Kernel:
     self.reduceops = dedup([x for x in ordered_parents(self.ast) if x.op is Ops.REDUCE_AXIS])
 
     self.vars: List[Variable] = self.ast.variables()
-    self.bufs: List[UOp] = [x for x in self.ast.parents if x.op in BUFFER_UOPS]
+    self.bufs: List[UOp] = [x for x in self.ast.parents if x.op in BUFOPS]
 
     # get earlybufs, before any reduceops
-    earlybufs: List[UOp] = [x for reduceop in self.reduceops for x in reduceop.parents if x.op in BUFFER_UOPS]
+    earlybufs: List[UOp] = [x for reduceop in self.reduceops for x in reduceop.parents if x.op in BUFOPS]
     self.full_buf_index: int = self.bufs.index(earlybufs[0]) if earlybufs else 0
 
     # NOTE: full_shape can be wrong if there's a tree of reduces
@@ -598,7 +598,7 @@ class Kernel:
   @functools.cached_property
   def name(self) -> str:
     # kernel name (before late upcast)
-    kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op in BUFFER_UOPS for x in self.ast.parents) else "E")
+    kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op in BUFOPS for x in self.ast.parents) else "E")
     suffix = colored('_', 'BLACK').join([colored(x.render() if isinstance(x, UOp) else str(x), c) for x,c in zip(self.full_shape, self.colors())])
     name = kernel_type + (f"{len(self.ast.src)}" if len(self.ast.src) > 1 else "") + "_" + suffix
 
@@ -613,7 +613,7 @@ class Kernel:
     @functools.lru_cache(None)
     def fixup_ast(op:UOp, apply_to_st=None) -> UOp:
       arg = op.arg
-      if op.op in BUFFER_UOPS:
+      if op.op in BUFOPS:
         # for locals, we use the ShapeTracker that's in the srcs
         st = op.st_arg if op.src[0].op is Ops.DEFINE_LOCAL else self.sts[self.bufs.index(op)]
         st_uop = (st if apply_to_st is None else apply_to_st(st)).to_uop()
@@ -728,7 +728,7 @@ class Kernel:
     # group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
     # TODO: these max and min don't work on symbolic, and results are very wrong.
     mem_bytes = sum(max(x.src[0].dtype.itemsize * x.st_arg.real_size() for x in group)
-      for _, group in itertools.groupby([x for x in self.ast.parents if x.op in BUFFER_UOPS and x.src[0].op is Ops.DEFINE_GLOBAL],
+      for _, group in itertools.groupby([x for x in self.ast.parents if x.op in BUFOPS and x.src[0].op is Ops.DEFINE_GLOBAL],
       key=lambda x: (x.op, x.src[0].arg)))
     return Program(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
                    global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 70b3dcc700..4e8412df7c 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -2,7 +2,7 @@ import sys, atexit, functools, itertools
 from collections import defaultdict, deque
 from dataclasses import dataclass, field
 from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast
-from tinygrad.ops import BUFFER_UOPS, MetaOps, ReduceOps, UnaryOps, UOp, Ops, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint
+from tinygrad.ops import BUFOPS, MetaOps, ReduceOps, UnaryOps, UOp, Ops, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint
 from tinygrad.helpers import DEBUG, Metadata, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
 from tinygrad.dtype import ImageDType, dtypes
 from tinygrad.shape.shapetracker import ShapeTracker
@@ -143,7 +143,7 @@ merge_views = PatternMatcher([(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="s0"),),
 
 # push VIEW to loads
 view_left = merge_views+PatternMatcher([
   # view before ALU
-  (UPat(Ops.VIEW, src=(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, *BUFFER_UOPS), name="e"),), name="v"),
+  (UPat(Ops.VIEW, src=(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, *BUFOPS), name="e"),), name="v"),
    lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
 ])
@@ -261,7 +261,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
   metadata: List[Set[Metadata]] = []
   for stores in store_groups:
     sink = UOp.sink(*(ctx.realizes[u] for u in stores))
-    metadata.append({mx for x in sink.sparents if x.op in BUFFER_UOPS and len(x.src) > 2 and (mx:=ctx.ubuf_metadata.get(x.src[0]))})
+    metadata.append({mx for x in sink.sparents if x.op in BUFOPS and len(x.src) > 2 and (mx:=ctx.ubuf_metadata.get(x.src[0]))})
     small_graphs.append(full_ast_rewrite(sink, ctx.var_vals, assigned))
 
   # do BFS
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index ce7b7e7ad1..2ac0739bd4 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -183,7 +183,7 @@ class Ops(FastEnum):
   VCONST = auto()
   CONST = auto()
 
-BUFFER_UOPS = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID}
+BUFOPS = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID}
 COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR}
 END_FOR_UOP = {Ops.IF:(Ops.STORE, Ops.ENDIF), Ops.RANGE:(Ops.ASSIGN, Ops.ENDRANGE)}
 
@@ -260,7 +260,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   @functools.cached_property
   def st(self) -> Optional[ShapeTracker]:
     if not self.has_st: return None
-    if self.op in BUFFER_UOPS: return self.st_arg
+    if self.op in BUFOPS: return self.st_arg
     if self.op is Ops.VIEW: return self.arg
     src_sts = [x.st for x in self.src if x.st is not None]
     assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}"
@@ -293,7 +293,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
 
   @property
   def st_arg(self) -> ShapeTracker:
-    assert self.op in BUFFER_UOPS, f"st_arg called on {self.op}"
+    assert self.op in BUFOPS, f"st_arg called on {self.op}"
     ret = self.src[0 if self.op is Ops.VALID else 1]
     assert ret.op is Ops.VIEW, f"st_arg trying to return {ret}"
     return ret.arg
@@ -368,7 +368,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     all_vars = set([x for x in self.sparents if x.op is Ops.DEFINE_VAR])
     return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
   def variables(self) -> List[Variable]:
-    st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
+    st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFOPS]
     return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
 
 # *** uop symbolic stuff ***
@@ -1139,4 +1139,4 @@ renderer = PatternMatcher([
 
 sint = Union[int, UOp]
 Variable = UOp
-ConstLike = Union[ConstType, Variable, Tuple[ConstType, ...]]
\ No newline at end of file
+ConstLike = Union[ConstType, Variable, Tuple[ConstType, ...]]
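
Note (not part of the patch): the rename is purely mechanical; BUFFER_UOPS and BUFOPS denote the same set {LOAD, PRELOAD, STORE, VALID}, and every call site is a membership test like `x.op in BUFOPS`. The short sketch below is a standalone illustration of that pattern, using toy stand-ins for tinygrad's Ops and UOp rather than the real classes, so the filtering done by e.g. Kernel.bufs in the diff can be run in isolation.

from enum import Enum, auto
from dataclasses import dataclass
from typing import Tuple

class Ops(Enum):  # toy stand-in, not tinygrad's FastEnum
  LOAD = auto(); PRELOAD = auto(); STORE = auto(); VALID = auto(); ALU = auto()

# the renamed set: membership marks a UOp as a buffer access
BUFOPS = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID}

@dataclass(frozen=True)
class UOp:  # toy stand-in with just enough structure to show the filter
  op: Ops
  src: Tuple["UOp", ...] = ()
  @property
  def parents(self) -> set:
    # transitive closure of sources, analogous to UOp.parents
    ret = set(self.src)
    for s in self.src: ret |= s.parents
    return ret

if __name__ == "__main__":
  store = UOp(Ops.STORE, (UOp(Ops.ALU, (UOp(Ops.LOAD),)),))
  bufs = [x for x in store.parents if x.op in BUFOPS]  # same shape as Kernel.bufs in the diff
  print([b.op for b in bufs])  # [Ops.LOAD] -- the ALU node is filtered out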