diff --git a/tinygrad/codegen/gpudims.py b/tinygrad/codegen/gpudims.py
index 303897c8de..60b28c8397 100644
--- a/tinygrad/codegen/gpudims.py
+++ b/tinygrad/codegen/gpudims.py
@@ -1,5 +1,5 @@
 import math
-from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType, sint_to_uop
+from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType
 from tinygrad.helpers import dedup, get_contraction
 from tinygrad.dtype import dtypes, AddrSpace, Invalid
 from tinygrad.renderer import Renderer
@@ -35,7 +35,7 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No
   if len(limited) > len(max_sizes): raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
   # try to split up dims: (a,) -> (b, c)
   if limited == dims: limited = _split_dims(dims, max_sizes)
-  raw_idxs = [UOp(Ops.SPECIAL, dtypes.weakint, (sint_to_uop(s),), (f"{prefix}{i}")) for i,s in enumerate(limited)]
+  raw_idxs = [UOp.special(s, f"{prefix}{i}") for i,s in enumerate(limited)]
   if len(limited) < len(dims):
     ret = []
     if (contraction:=get_contraction(dims, limited)) is None: raise RuntimeError(f"get_contraction should not be None {dims=} {limited=}")
diff --git a/tinygrad/codegen/opt/postrange.py b/tinygrad/codegen/opt/postrange.py
index dfffc27dbb..cfda8a5449 100644
--- a/tinygrad/codegen/opt/postrange.py
+++ b/tinygrad/codegen/opt/postrange.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 import math, itertools
 from collections import defaultdict
 from typing import cast, Final
-from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp
+from tinygrad.uop.ops import Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp, remove_all_tags
 from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos
 from tinygrad.device import Buffer
 from tinygrad.dtype import dtypes
@@ -12,8 +12,6 @@ from tinygrad.codegen.opt import Opt, OptOps, KernelOptError, check
 from tinygrad.codegen.simplify import pm_flatten_range
 from tinygrad.renderer import Renderer
 
-remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
-
 class Scheduler:
   def __init__(self, ast:UOp, ren:Renderer):
     self.ast, self.ren = ast, ren
@@ -211,7 +209,7 @@ class Scheduler:
       self.ast = self.ast.substitute({rng:rng.replace(arg=(*altrng.arg[0:-1], rng.arg[-1]), tag=1),
                                       altrng:altrng.replace(arg=(*rng.arg[0:-1], altrng.arg[-1]), tag=1)}, name=f"swap {rng.arg[:-1]} {altrng.arg[:-1]}")
-      self.ast = graph_rewrite(self.ast, remove_tags, name="swap remove tags")
+      self.ast = graph_rewrite(self.ast, remove_all_tags, name="swap remove tags")
     else:
       raise KernelOptError(f"unsupported opt {opt.op}")
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 2d8039b53c..c4db3bf101 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -804,7 +804,7 @@ class Tensor(OpMixin):
       return Tensor.full(self.shape, fill_value, dtype=dtype or self.dtype, device=device).requires_grad_(requires_grad)
     if requires_grad:
       return Tensor.full(self.shape, fill_value, dtype=dtype or self.dtype, device=self.device).requires_grad_(requires_grad)
-    return self.const_like(fill_value) if dtype is None else self.const_like(fill_value).cast(dtype)
+    return super().full_like(fill_value, dtype)
 
   def rand_like(self, **kwargs) -> Tensor:
     """
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 60750ea91f..cd5e55bc8d 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -1480,7 +1480,7 @@ def _index_to_concrete_int(u:UOp) -> UOp: return graph_rewrite(u.sink(), pm_lowe
 _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
 _pm_resolve_params = PatternMatcher([(UPat(Ops.PARAM, name="p"), lambda ctx,p: ctx[p.arg])])
 
-_remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
+remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
 
 def gate_kernel_sink(x:UOp) -> bool:
   if x.op is Ops.LINEAR: return False