some more ops.py cleanups (#12525)

* remove GroupOp.Meta and st_arg

* inline axis_arg

* only allow .buffer on reshapes (or the buffer)

* gate is the other way

* still want can_pad?

* use op_in_backward_slice_with_self

* .buffer is recursive

* lint

* pathlib there
This commit is contained in:
qazal
2025-10-09 06:06:44 +03:00
committed by GitHub
parent be05028419
commit bb5671a837
5 changed files with 10 additions and 22 deletions

View File

@@ -432,7 +432,7 @@ def helper_realized_ast(r:Tensor|list[Tensor]) -> tuple[UOp, list[Buffer]]:
def helper_linearizer_ast(ast:UOp, inputs:list[Tensor], *args, **kwargs):
assert isinstance(ast, UOp), "ast must be UOp"
inbufs = [x.uop.base.buffer for x in inputs]
outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.st_arg.size, out.src[1].dtype).allocate() for out in ast.src]
outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.size, out.src[1].dtype).allocate() for out in ast.src]
_helper_linearizer_opt_ast(ast, outbufs+inbufs, *args, **kwargs)
def helper_linearizer_opt(r:Tensor|list[Tensor], *args, **kwargs):

View File

@@ -2289,7 +2289,7 @@ class TestBufferUOp(unittest.TestCase):
def test_buffer_view_not_allowed(self):
permuted_view = Tensor.empty(1, 2, 3).permute(0, 2, 1)
with self.assertRaisesRegex(AssertionError, "VIEW only works here if it's contiguous"):
with self.assertRaisesRegex(AssertionError, "can only be RESHAPE"):
permuted_view.uop.buffer # cannot access Buffer of a non contiguous VIEW
def test_buffer_only_after_realize(self):

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import math, itertools
from collections import defaultdict
from typing import cast, Final
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, can_pad, GroupOp
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp
from tinygrad.device import Buffer
from tinygrad.dtype import AddrSpace, dtypes, ImageDType
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element
@@ -188,7 +188,7 @@ class Scheduler:
check(rng.arg[-1] is not AxisType.THREAD, "cannot pad thread")
# ok to pad SUM if all parent ALU ops have f(0) = 0
if (r:=self.reduceop) is not None and rng.arg[-1] in (AxisType.GROUP_REDUCE, AxisType.REDUCE):
check(r.arg[0] is Ops.ADD and can_pad(r, {}), f"cannot pad {r}")
check(r.arg[0] is Ops.ADD and not r.op_in_backward_slice_with_self(*GroupOp.UnsafePad), f"cannot pad {r}")
new_sz = round_up(int(rng.vmax+1), cast(int, opt.arg))
check(rng.vmax+1 > new_sz//4, "pad adds more than quadruple the work")
replaced_rng = UOp.range(new_sz, *rng.arg)

View File

@@ -108,6 +108,4 @@ class GroupOp:
# do not preserve f(0) = 0
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}
Meta = {Ops.COPY, Ops.BUFFER_VIEW}
All = set(Ops)

View File

@@ -23,9 +23,6 @@ range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3}
# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
def can_pad(root:UOp, edges:dict[UOp, None]) -> bool:
return all(u.op not in GroupOp.UnsafePad for u in root.toposort(gate=lambda x:x not in edges))
# With True as the default, this matches the old symbolic behavior
def resolve(x:UOp|bool, default:bool=True):
if isinstance(x, bool): return x
@@ -223,7 +220,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
case Ops.BITCAST:
shape = src_sts[0].shape
if self.dtype.itemsize != (input_sz:=self.src[0].dtype.itemsize): shape = shape[:-1]+((shape[-1]*input_sz) // self.dtype.itemsize,)
case Ops.REDUCE_AXIS | Ops.WMMA: shape = src_sts[0].reduce(self.axis_arg)
case Ops.REDUCE_AXIS | Ops.WMMA:
axis_arg = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
assert isinstance(axis_arg, tuple) and all(isinstance(x, int) for x in axis_arg), f"invalid type for axis: {axis_arg}"
shape = src_sts[0].reduce(axis_arg)
case _: shape = src_sts[0].shape
return ShapeTracker.from_shape(shape)
@@ -286,16 +286,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# *** uop syntactic sugar ***
@property
def st_arg(self) -> ShapeTracker:
assert self.op in GroupOp.Buffer, f"st_arg called on {self.op}"
return unwrap(self.st)
@property
def axis_arg(self) -> tuple[int, ...]:
assert self.op in {Ops.REDUCE_AXIS, Ops.WMMA}, f"axis_arg called on {self.op}"
ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}"
return ret
def sink(*srcs:UOp|None, **kwargs): # pylint: disable=no-self-argument
return UOp(Ops.SINK, dtypes.void, tuple([x for x in srcs if x is not None]), **kwargs)
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
@@ -501,7 +491,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def buffer(self) -> Buffer|MultiBuffer:
from tinygrad.device import Buffer, MultiBuffer
if self is not self.base:
assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous"
assert self.op is Ops.RESHAPE, f"can only be RESHAPE {self}"
return self.src[0].buffer
if self.op is Ops.MSELECT:
ret = self.src[0].buffer
@@ -992,7 +982,7 @@ if TRACK_MATCH_STATS or PROFILE:
if not int(os.getenv("VIZ", "0")) and not int(os.getenv("PROFILE", "0")) and not int(os.getenv("SQTT", "0")):
args = ['--kernels', getenv("VIZ_DATA", "")] if getenv("VIZ_DATA", "") else []
args += ['--profile', getenv("PROFILE_DATA", "")] if getenv("PROFILE_DATA", "") else []
os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), "../", "viz", "serve.py")] + args)
os.execv(sys.executable, [sys.executable] + [pathlib.Path(__file__).resolve().parent.parent / "viz" / "serve.py"] + args)
# *** simple graph rewrite engine ***