From bb5671a83796f463d9bd9c72325b3c4916710530 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Thu, 9 Oct 2025 06:06:44 +0300
Subject: [PATCH] some more ops.py cleanups (#12525)

* remove GroupOp.Meta and st_arg

* inline axis_arg

* only allow .buffer on reshapes (or the buffer)

* gate is the other way

* still want can_pad?

* use op_in_backward_slice_with_self

* .buffer is recursive

* lint

* pathlib there
---
 test/test_linearizer.py           |  2 +-
 test/test_schedule.py             |  2 +-
 tinygrad/codegen/opt/postrange.py |  4 ++--
 tinygrad/uop/__init__.py          |  2 --
 tinygrad/uop/ops.py               | 22 ++++++----------------
 5 files changed, 10 insertions(+), 22 deletions(-)

diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index c2ae2c3990..2f9667fb12 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -432,7 +432,7 @@ def helper_realized_ast(r:Tensor|list[Tensor]) -> tuple[UOp, list[Buffer]]:
 def helper_linearizer_ast(ast:UOp, inputs:list[Tensor], *args, **kwargs):
   assert isinstance(ast, UOp), "ast must be UOp"
   inbufs = [x.uop.base.buffer for x in inputs]
-  outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.st_arg.size, out.src[1].dtype).allocate() for out in ast.src]
+  outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.size, out.src[1].dtype).allocate() for out in ast.src]
   _helper_linearizer_opt_ast(ast, outbufs+inbufs, *args, **kwargs)

 def helper_linearizer_opt(r:Tensor|list[Tensor], *args, **kwargs):
diff --git a/test/test_schedule.py b/test/test_schedule.py
index 87c56166cb..e13292e0e5 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -2289,7 +2289,7 @@ class TestBufferUOp(unittest.TestCase):

   def test_buffer_view_not_allowed(self):
     permuted_view = Tensor.empty(1, 2, 3).permute(0, 2, 1)
-    with self.assertRaisesRegex(AssertionError, "VIEW only works here if it's contiguous"):
+    with self.assertRaisesRegex(AssertionError, "can only be RESHAPE"):
       permuted_view.uop.buffer # cannot access Buffer of a non contiguous VIEW

   def test_buffer_only_after_realize(self):
diff --git a/tinygrad/codegen/opt/postrange.py b/tinygrad/codegen/opt/postrange.py
index 99925fb358..7cd45ef2c7 100644
--- a/tinygrad/codegen/opt/postrange.py
+++ b/tinygrad/codegen/opt/postrange.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 import math, itertools
 from collections import defaultdict
 from typing import cast, Final
-from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, can_pad, GroupOp
+from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp
 from tinygrad.device import Buffer
 from tinygrad.dtype import AddrSpace, dtypes, ImageDType
 from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element
@@ -188,7 +188,7 @@ class Scheduler:
         check(rng.arg[-1] is not AxisType.THREAD, "cannot pad thread")
         # ok to pad SUM if all parent ALU ops have f(0) = 0
         if (r:=self.reduceop) is not None and rng.arg[-1] in (AxisType.GROUP_REDUCE, AxisType.REDUCE):
-          check(r.arg[0] is Ops.ADD and can_pad(r, {}), f"cannot pad {r}")
+          check(r.arg[0] is Ops.ADD and not r.op_in_backward_slice_with_self(*GroupOp.UnsafePad), f"cannot pad {r}")
         new_sz = round_up(int(rng.vmax+1), cast(int, opt.arg))
         check(rng.vmax+1 > new_sz//4, "pad adds more than quadruple the work")
         replaced_rng = UOp.range(new_sz, *rng.arg)
diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py
index 95ccb0b54f..2922fd4471 100644
--- a/tinygrad/uop/__init__.py
+++ b/tinygrad/uop/__init__.py
@@ -108,6 +108,4 @@ class GroupOp:
   # do not preserve f(0) = 0
   UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}

-  Meta = {Ops.COPY, Ops.BUFFER_VIEW}
-
   All = set(Ops)
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 2622997923..bfa7a911c7 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -23,9 +23,6 @@ range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3}
 # https://en.wikipedia.org/wiki/Identity_element
 def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)

-def can_pad(root:UOp, edges:dict[UOp, None]) -> bool:
-  return all(u.op not in GroupOp.UnsafePad for u in root.toposort(gate=lambda x:x not in edges))
-
 # With True as the default, this matches the old symbolic behavior
 def resolve(x:UOp|bool, default:bool=True):
   if isinstance(x, bool): return x
@@ -223,7 +220,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
       case Ops.BITCAST:
         shape = src_sts[0].shape
         if self.dtype.itemsize != (input_sz:=self.src[0].dtype.itemsize): shape = shape[:-1]+((shape[-1]*input_sz) // self.dtype.itemsize,)
-      case Ops.REDUCE_AXIS | Ops.WMMA: shape = src_sts[0].reduce(self.axis_arg)
+      case Ops.REDUCE_AXIS | Ops.WMMA:
+        axis_arg = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
+        assert isinstance(axis_arg, tuple) and all(isinstance(x, int) for x in axis_arg), f"invalid type for axis: {axis_arg}"
+        shape = src_sts[0].reduce(axis_arg)
       case _: shape = src_sts[0].shape
     return ShapeTracker.from_shape(shape)

@@ -286,16 +286,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):

   # *** uop syntactic sugar ***

-  @property
-  def st_arg(self) -> ShapeTracker:
-    assert self.op in GroupOp.Buffer, f"st_arg called on {self.op}"
-    return unwrap(self.st)
-  @property
-  def axis_arg(self) -> tuple[int, ...]:
-    assert self.op in {Ops.REDUCE_AXIS, Ops.WMMA}, f"axis_arg called on {self.op}"
-    ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
-    assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}"
-    return ret
   def sink(*srcs:UOp|None, **kwargs): # pylint: disable=no-self-argument
     return UOp(Ops.SINK, dtypes.void, tuple([x for x in srcs if x is not None]), **kwargs)
   def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
@@ -501,7 +491,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def buffer(self) -> Buffer|MultiBuffer:
     from tinygrad.device import Buffer, MultiBuffer
     if self is not self.base:
-      assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous"
+      assert self.op is Ops.RESHAPE, f"can only be RESHAPE {self}"
       return self.src[0].buffer
     if self.op is Ops.MSELECT:
       ret = self.src[0].buffer
@@ -992,7 +982,7 @@ if TRACK_MATCH_STATS or PROFILE:
     if not int(os.getenv("VIZ", "0")) and not int(os.getenv("PROFILE", "0")) and not int(os.getenv("SQTT", "0")):
       args = ['--kernels', getenv("VIZ_DATA", "")] if getenv("VIZ_DATA", "") else []
       args += ['--profile', getenv("PROFILE_DATA", "")] if getenv("PROFILE_DATA", "") else []
-      os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), "../", "viz", "serve.py")] + args)
+      os.execv(sys.executable, [sys.executable] + [pathlib.Path(__file__).resolve().parent.parent / "viz" / "serve.py"] + args)

 # *** simple graph rewrite engine ***

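
The postrange.py hunk replaces the removed can_pad helper with UOp.op_in_backward_slice_with_self, which reports whether any op from a given set appears in the node's backward slice (the node itself plus everything it computes from). A standalone sketch of that membership test, with Node and the string op names as illustrative stand-ins rather than tinygrad API:

from __future__ import annotations
from dataclasses import dataclass

# ops that do not preserve f(0) = 0, mirroring GroupOp.UnsafePad above
UNSAFE_PAD = {"RECIP", "LOG2", "EXP2", "IDIV", "POW"}

@dataclass(frozen=True)
class Node:
  op: str
  srcs: tuple[Node, ...] = ()

def op_in_backward_slice_with_self(root: Node, *ops: str) -> bool:
  # depth-first walk over root and everything it computes from
  stack, seen = [root], set()
  while stack:
    n = stack.pop()
    if id(n) in seen: continue
    seen.add(id(n))
    if n.op in ops: return True
    stack.extend(n.srcs)
  return False

# a SUM whose inputs only pass through MUL is safe to zero-pad...
safe = Node("ADD", (Node("MUL", (Node("LOAD"),)),))
assert not op_in_backward_slice_with_self(safe, *UNSAFE_PAD)
# ...but RECIP maps the padded zeros to inf, so padding must be rejected
unsafe = Node("ADD", (Node("RECIP", (Node("LOAD"),)),))
assert op_in_backward_slice_with_self(unsafe, *UNSAFE_PAD)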
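
The new .buffer gate in ops.py recurses through RESHAPE views only and fails fast on anything else, which is what the updated test_schedule.py assertion checks. A minimal illustration of the intended user-visible behavior, assuming a realized tensor since a Buffer only exists after realize:

from tinygrad import Tensor

t = Tensor.empty(2, 3).realize()
print(t.reshape(3, 2).uop.buffer)  # ok: RESHAPE recurses down to the base buffer
try:
  t.permute(1, 0).uop.buffer       # any other view now trips the assert
except AssertionError as e:
  print(e)                         # "can only be RESHAPE ..."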
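
The last hunk swaps the os.path spelling of the serve.py path for pathlib. Both point at the same file, but the pathlib form normalizes away the literal "../" segment, and os.execv accepts path-like argv entries, so the Path needs no str() wrapper. A quick comparison, using a hypothetical checkout root "/repo":

import os, pathlib

f = "/repo/tinygrad/uop/ops.py"  # hypothetical __file__
old = os.path.join(os.path.dirname(f), "../", "viz", "serve.py")
new = pathlib.Path(f).resolve().parent.parent / "viz" / "serve.py"
print(old)  # /repo/tinygrad/uop/../viz/serve.py  (the "../" survives unnormalized)
print(new)  # /repo/tinygrad/viz/serve.py  (resolved, absent symlinks)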