mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
some more ops.py cleanups (#12525)
* remove GroupOp.Meta and st_arg * inline axis_arg * only allow .buffer on reshapes (or the buffer) * gate is the other way * still want can_pad? * use op_in_backward_slice_with_self * .buffer is recursive * lint * pathlib there
This commit is contained in:
@@ -432,7 +432,7 @@ def helper_realized_ast(r:Tensor|list[Tensor]) -> tuple[UOp, list[Buffer]]:
|
||||
def helper_linearizer_ast(ast:UOp, inputs:list[Tensor], *args, **kwargs):
|
||||
assert isinstance(ast, UOp), "ast must be UOp"
|
||||
inbufs = [x.uop.base.buffer for x in inputs]
|
||||
outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.st_arg.size, out.src[1].dtype).allocate() for out in ast.src]
|
||||
outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.size, out.src[1].dtype).allocate() for out in ast.src]
|
||||
_helper_linearizer_opt_ast(ast, outbufs+inbufs, *args, **kwargs)
|
||||
|
||||
def helper_linearizer_opt(r:Tensor|list[Tensor], *args, **kwargs):
|
||||
|
||||
@@ -2289,7 +2289,7 @@ class TestBufferUOp(unittest.TestCase):
|
||||
|
||||
def test_buffer_view_not_allowed(self):
|
||||
permuted_view = Tensor.empty(1, 2, 3).permute(0, 2, 1)
|
||||
with self.assertRaisesRegex(AssertionError, "VIEW only works here if it's contiguous"):
|
||||
with self.assertRaisesRegex(AssertionError, "can only be RESHAPE"):
|
||||
permuted_view.uop.buffer # cannot access Buffer of a non contiguous VIEW
|
||||
|
||||
def test_buffer_only_after_realize(self):
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
import math, itertools
|
||||
from collections import defaultdict
|
||||
from typing import cast, Final
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, can_pad, GroupOp
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.dtype import AddrSpace, dtypes, ImageDType
|
||||
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element
|
||||
@@ -188,7 +188,7 @@ class Scheduler:
|
||||
check(rng.arg[-1] is not AxisType.THREAD, "cannot pad thread")
|
||||
# ok to pad SUM if all parent ALU ops have f(0) = 0
|
||||
if (r:=self.reduceop) is not None and rng.arg[-1] in (AxisType.GROUP_REDUCE, AxisType.REDUCE):
|
||||
check(r.arg[0] is Ops.ADD and can_pad(r, {}), f"cannot pad {r}")
|
||||
check(r.arg[0] is Ops.ADD and not r.op_in_backward_slice_with_self(*GroupOp.UnsafePad), f"cannot pad {r}")
|
||||
new_sz = round_up(int(rng.vmax+1), cast(int, opt.arg))
|
||||
check(rng.vmax+1 > new_sz//4, "pad adds more than quadruple the work")
|
||||
replaced_rng = UOp.range(new_sz, *rng.arg)
|
||||
|
||||
@@ -108,6 +108,4 @@ class GroupOp:
|
||||
# do not preserve f(0) = 0
|
||||
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}
|
||||
|
||||
Meta = {Ops.COPY, Ops.BUFFER_VIEW}
|
||||
|
||||
All = set(Ops)
|
||||
|
||||
@@ -23,9 +23,6 @@ range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3}
|
||||
# https://en.wikipedia.org/wiki/Identity_element
|
||||
def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
|
||||
|
||||
def can_pad(root:UOp, edges:dict[UOp, None]) -> bool:
|
||||
return all(u.op not in GroupOp.UnsafePad for u in root.toposort(gate=lambda x:x not in edges))
|
||||
|
||||
# With True as the default, this matches the old symbolic behavior
|
||||
def resolve(x:UOp|bool, default:bool=True):
|
||||
if isinstance(x, bool): return x
|
||||
@@ -223,7 +220,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
case Ops.BITCAST:
|
||||
shape = src_sts[0].shape
|
||||
if self.dtype.itemsize != (input_sz:=self.src[0].dtype.itemsize): shape = shape[:-1]+((shape[-1]*input_sz) // self.dtype.itemsize,)
|
||||
case Ops.REDUCE_AXIS | Ops.WMMA: shape = src_sts[0].reduce(self.axis_arg)
|
||||
case Ops.REDUCE_AXIS | Ops.WMMA:
|
||||
axis_arg = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
|
||||
assert isinstance(axis_arg, tuple) and all(isinstance(x, int) for x in axis_arg), f"invalid type for axis: {axis_arg}"
|
||||
shape = src_sts[0].reduce(axis_arg)
|
||||
case _: shape = src_sts[0].shape
|
||||
return ShapeTracker.from_shape(shape)
|
||||
|
||||
@@ -286,16 +286,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
|
||||
# *** uop syntactic sugar ***
|
||||
|
||||
@property
|
||||
def st_arg(self) -> ShapeTracker:
|
||||
assert self.op in GroupOp.Buffer, f"st_arg called on {self.op}"
|
||||
return unwrap(self.st)
|
||||
@property
|
||||
def axis_arg(self) -> tuple[int, ...]:
|
||||
assert self.op in {Ops.REDUCE_AXIS, Ops.WMMA}, f"axis_arg called on {self.op}"
|
||||
ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
|
||||
assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}"
|
||||
return ret
|
||||
def sink(*srcs:UOp|None, **kwargs): # pylint: disable=no-self-argument
|
||||
return UOp(Ops.SINK, dtypes.void, tuple([x for x in srcs if x is not None]), **kwargs)
|
||||
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
|
||||
@@ -501,7 +491,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def buffer(self) -> Buffer|MultiBuffer:
|
||||
from tinygrad.device import Buffer, MultiBuffer
|
||||
if self is not self.base:
|
||||
assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous"
|
||||
assert self.op is Ops.RESHAPE, f"can only be RESHAPE {self}"
|
||||
return self.src[0].buffer
|
||||
if self.op is Ops.MSELECT:
|
||||
ret = self.src[0].buffer
|
||||
@@ -992,7 +982,7 @@ if TRACK_MATCH_STATS or PROFILE:
|
||||
if not int(os.getenv("VIZ", "0")) and not int(os.getenv("PROFILE", "0")) and not int(os.getenv("SQTT", "0")):
|
||||
args = ['--kernels', getenv("VIZ_DATA", "")] if getenv("VIZ_DATA", "") else []
|
||||
args += ['--profile', getenv("PROFILE_DATA", "")] if getenv("PROFILE_DATA", "") else []
|
||||
os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), "../", "viz", "serve.py")] + args)
|
||||
os.execv(sys.executable, [sys.executable] + [pathlib.Path(__file__).resolve().parent.parent / "viz" / "serve.py"] + args)
|
||||
|
||||
# *** simple graph rewrite engine ***
|
||||
|
||||
|
||||
Reference in New Issue
Block a user