remove_tags and _remove_all_tags are the same [pr] (#15819)
also other small UOp method cleanups
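Both copies of the tag-stripping matcher were token-for-token identical: `remove_tags`, defined locally in the scheduler, and `_remove_all_tags`, private to `tinygrad.uop.ops`. This PR deletes the scheduler copy and makes the ops one public as `remove_all_tags`. A minimal sketch of the shared matcher (the pattern body is verbatim from the diff; the usage comment mirrors the scheduler call site):

```python
from tinygrad.uop.ops import UPat, PatternMatcher, GroupOp, graph_rewrite

# strip the tag from every tagged UOp in the graph; untagged nodes return None (no rewrite)
remove_all_tags = PatternMatcher([
  (UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None),
])

# applied as a whole-graph rewrite, e.g.: ast = graph_rewrite(ast, remove_all_tags, name="remove tags")
```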
@@ -1,5 +1,5 @@
 import math
-from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType, sint_to_uop
+from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType
 from tinygrad.helpers import dedup, get_contraction
 from tinygrad.dtype import dtypes, AddrSpace, Invalid
 from tinygrad.renderer import Renderer
@@ -35,7 +35,7 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No
   if len(limited) > len(max_sizes): raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
   # try to split up dims: (a,) -> (b, c)
   if limited == dims: limited = _split_dims(dims, max_sizes)
-  raw_idxs = [UOp(Ops.SPECIAL, dtypes.weakint, (sint_to_uop(s),), (f"{prefix}{i}")) for i,s in enumerate(limited)]
+  raw_idxs = [UOp.special(s, f"{prefix}{i}") for i,s in enumerate(limited)]
   if len(limited) < len(dims):
     ret = []
     if (contraction:=get_contraction(dims, limited)) is None: raise RuntimeError(f"get_contraction should not be None {dims=} {limited=}")
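The hunk pair above is one of the "small UOp method cleanups": the hand-rolled `UOp(Ops.SPECIAL, ...)` construction for grid indices becomes a `UOp.special` call, which also lets the `sint_to_uop` import go away. A hedged sketch of the equivalence, with the dtype and arg layout read off the removed line (`special` below is a hypothetical stand-in, not the real classmethod body):

```python
from tinygrad.uop.ops import UOp, Ops, sint_to_uop
from tinygrad.dtype import dtypes

# hypothetical stand-in: what UOp.special(s, name) presumably constructs,
# judging from the removed call site above
def special(s, name: str) -> UOp:
  return UOp(Ops.SPECIAL, dtypes.weakint, (sint_to_uop(s),), name)
```

Factoring this into a classmethod keeps the SPECIAL encoding (dtype, src, arg) in one place instead of at every call site.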
@@ -2,7 +2,7 @@ from __future__ import annotations
 import math, itertools
 from collections import defaultdict
 from typing import cast, Final
-from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp
+from tinygrad.uop.ops import Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp, remove_all_tags
 from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos
 from tinygrad.device import Buffer
 from tinygrad.dtype import dtypes
@@ -12,8 +12,6 @@ from tinygrad.codegen.opt import Opt, OptOps, KernelOptError, check
 from tinygrad.codegen.simplify import pm_flatten_range
 from tinygrad.renderer import Renderer
 
-remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
-
 class Scheduler:
   def __init__(self, ast:UOp, ren:Renderer):
     self.ast, self.ren = ast, ren
@@ -211,7 +209,7 @@ class Scheduler:
       self.ast = self.ast.substitute({rng:rng.replace(arg=(*altrng.arg[0:-1], rng.arg[-1]), tag=1),
                                       altrng:altrng.replace(arg=(*rng.arg[0:-1], altrng.arg[-1]), tag=1)},
                                      name=f"swap {rng.arg[:-1]} {altrng.arg[:-1]}")
-      self.ast = graph_rewrite(self.ast, remove_tags, name="swap remove tags")
+      self.ast = graph_rewrite(self.ast, remove_all_tags, name="swap remove tags")
     else:
       raise KernelOptError(f"unsupported opt {opt.op}")
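For context on the SWAP hunk: the two RANGE ops are substituted with `tag=1` so the rewrite cannot match its own output, then every tag is cleared in one pass. A minimal round-trip sketch of that tag-then-strip idiom (assuming `UOp.const` and `.sink()` behave as they do elsewhere in the codebase):

```python
from tinygrad.uop.ops import UOp, graph_rewrite, remove_all_tags
from tinygrad.dtype import dtypes

tagged = UOp.const(dtypes.int, 1).replace(tag=1)  # mark a node as already rewritten
out = graph_rewrite(tagged.sink(), remove_all_tags, name="strip tags")
assert out.src[0].tag is None  # the bulk pass removed the marker
```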
@@ -804,7 +804,7 @@ class Tensor(OpMixin):
       return Tensor.full(self.shape, fill_value, dtype=dtype or self.dtype, device=device).requires_grad_(requires_grad)
     if requires_grad:
       return Tensor.full(self.shape, fill_value, dtype=dtype or self.dtype, device=self.device).requires_grad_(requires_grad)
-    return self.const_like(fill_value) if dtype is None else self.const_like(fill_value).cast(dtype)
+    return super().full_like(fill_value, dtype)
 
   def rand_like(self, **kwargs) -> Tensor:
     """
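`full_like` keeps its `device`/`requires_grad` special cases but now delegates the plain fallback to `super().full_like(fill_value, dtype)` (presumably provided by `OpMixin`) instead of spelling out the `const_like`/`cast` logic inline. Call-site behavior should be unchanged; a brief usage sketch:

```python
from tinygrad import Tensor, dtypes

t = Tensor.ones(2, 2)
a = t.full_like(3.0)                    # same shape/device/dtype, filled with 3.0
b = t.full_like(3, dtype=dtypes.int32)  # dtype override still applies
```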
@@ -1480,7 +1480,7 @@ def _index_to_concrete_int(u:UOp) -> UOp: return graph_rewrite(u.sink(), pm_lowe
 
 _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
 _pm_resolve_params = PatternMatcher([(UPat(Ops.PARAM, name="p"), lambda ctx,p: ctx[p.arg])])
-_remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
+remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
 
 def gate_kernel_sink(x:UOp) -> bool:
   if x.op is Ops.LINEAR: return False
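With the leading underscore dropped, the matcher is part of the module's public surface, which is exactly what the scheduler's new import relies on:

```python
from tinygrad.uop.ops import remove_all_tags
```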