remove_tags and _remove_all_tags are the same [pr] (#15819)

also other small UOp method cleanups
This commit is contained in:
chenyu
2026-04-19 21:37:49 -04:00
committed by GitHub
parent a1696e8413
commit 538841d1f2
4 changed files with 6 additions and 8 deletions

View File

@@ -1,5 +1,5 @@
import math
from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType, sint_to_uop
from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType
from tinygrad.helpers import dedup, get_contraction
from tinygrad.dtype import dtypes, AddrSpace, Invalid
from tinygrad.renderer import Renderer
@@ -35,7 +35,7 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No
if len(limited) > len(max_sizes): raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
# try to split up dims: (a,) -> (b, c)
if limited == dims: limited = _split_dims(dims, max_sizes)
raw_idxs = [UOp(Ops.SPECIAL, dtypes.weakint, (sint_to_uop(s),), (f"{prefix}{i}")) for i,s in enumerate(limited)]
raw_idxs = [UOp.special(s, f"{prefix}{i}") for i,s in enumerate(limited)]
if len(limited) < len(dims):
ret = []
if (contraction:=get_contraction(dims, limited)) is None: raise RuntimeError(f"get_contraction should not be None {dims=} {limited=}")

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import math, itertools
from collections import defaultdict
from typing import cast, Final
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp
from tinygrad.uop.ops import Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp, remove_all_tags
from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos
from tinygrad.device import Buffer
from tinygrad.dtype import dtypes
@@ -12,8 +12,6 @@ from tinygrad.codegen.opt import Opt, OptOps, KernelOptError, check
from tinygrad.codegen.simplify import pm_flatten_range
from tinygrad.renderer import Renderer
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
class Scheduler:
def __init__(self, ast:UOp, ren:Renderer):
self.ast, self.ren = ast, ren
@@ -211,7 +209,7 @@ class Scheduler:
self.ast = self.ast.substitute({rng:rng.replace(arg=(*altrng.arg[0:-1], rng.arg[-1]), tag=1),
altrng:altrng.replace(arg=(*rng.arg[0:-1], altrng.arg[-1]), tag=1)},
name=f"swap {rng.arg[:-1]} {altrng.arg[:-1]}")
self.ast = graph_rewrite(self.ast, remove_tags, name="swap remove tags")
self.ast = graph_rewrite(self.ast, remove_all_tags, name="swap remove tags")
else:
raise KernelOptError(f"unsupported opt {opt.op}")

View File

@@ -804,7 +804,7 @@ class Tensor(OpMixin):
return Tensor.full(self.shape, fill_value, dtype=dtype or self.dtype, device=device).requires_grad_(requires_grad)
if requires_grad:
return Tensor.full(self.shape, fill_value, dtype=dtype or self.dtype, device=self.device).requires_grad_(requires_grad)
return self.const_like(fill_value) if dtype is None else self.const_like(fill_value).cast(dtype)
return super().full_like(fill_value, dtype)
def rand_like(self, **kwargs) -> Tensor:
"""

View File

@@ -1480,7 +1480,7 @@ def _index_to_concrete_int(u:UOp) -> UOp: return graph_rewrite(u.sink(), pm_lowe
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
_pm_resolve_params = PatternMatcher([(UPat(Ops.PARAM, name="p"), lambda ctx,p: ctx[p.arg])])
_remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
def gate_kernel_sink(x:UOp) -> bool:
if x.op is Ops.LINEAR: return False