Lower index dtype spec fix (#12337)

* new pm_lower_index_dtype

* load_store_indexing after index lowering

* shorten line

* seperate rule for long removal

* fix test

* fix index_to_concrete_int

* minor fixes

* add sink there

* update types in linearizer test
This commit is contained in:
Sieds Lykles
2025-09-30 04:26:50 +02:00
committed by GitHub
parent c38f6ce140
commit d55d829635
6 changed files with 36 additions and 36 deletions

View File

@@ -15,18 +15,18 @@ class TestLinearizerFailure(unittest.TestCase):
@unittest.skipUnless(Device.DEFAULT == "METAL", "only tested on METAL")
def test_failure_beam_mnist(self):
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(4014080), arg=0, src=())
c1 = UOp.range(UOp.const(dtypes.int, 512), 0, AxisType.GLOBAL)
c2 = UOp.range(UOp.const(dtypes.int, 784), 1, AxisType.GLOBAL)
c3 = UOp.range(UOp.const(dtypes.int, 10), 3, AxisType.GLOBAL)
c1 = UOp.range(UOp.const(dtypes.index, 512), 0, AxisType.GLOBAL)
c2 = UOp.range(UOp.const(dtypes.index, 784), 1, AxisType.GLOBAL)
c3 = UOp.range(UOp.const(dtypes.index, 10), 3, AxisType.GLOBAL)
c4 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(512), arg=1, src=())
c5 = c4.index(c1, UOp.const(dtypes.bool, True)).load()
c6 = UOp.range(UOp.const(dtypes.int, 6000), 1004, AxisType.REDUCE)
c7 = UOp.range(UOp.const(dtypes.int, 3750), 2006, AxisType.REDUCE)
c8 = UOp.range(UOp.const(dtypes.int, 16), 2007, AxisType.GROUP_REDUCE)
c5 = c4.index(c1.valid(UOp.const(dtypes.bool, True))).load()
c6 = UOp.range(UOp.const(dtypes.index, 6000), 1004, AxisType.REDUCE)
c7 = UOp.range(UOp.const(dtypes.index, 3750), 2006, AxisType.REDUCE)
c8 = UOp.range(UOp.const(dtypes.index, 16), 2007, AxisType.GROUP_REDUCE)
c9 = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(47040000), arg=2, src=())
c10 = c9.index((((c3*UOp.const(dtypes.int, 4704000))+c2)+(c6*UOp.const(dtypes.int, 784))), UOp.const(dtypes.bool, True)).load()
c11 = c5.alu(Ops.CMPNE, ((((c3*UOp.const(dtypes.int, 6000))+c6)+((c7*UOp.const(dtypes.int, 16))+c8)).alu(Ops.CMPLT, UOp.const(dtypes.int, 59999)).where(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)).reduce(c7, c8, arg=Ops.ADD)+UOp.const(dtypes.int, -1))).where(UOp.const(dtypes.uchar, 0), c10).reduce(c6, arg=Ops.ADD)
c12 = c0.index((((c1*UOp.const(dtypes.int, 7840))+(c2*UOp.const(dtypes.int, 10)))+c3), UOp.const(dtypes.bool, True)).store(c11, c1, c2, c3)
c10 = c9.index((((c3*UOp.const(dtypes.index, 4704000))+c2)+(c6*UOp.const(dtypes.index, 784))).valid(UOp.const(dtypes.bool, True))).load()
c11 = c5.alu(Ops.CMPNE, ((((c3*UOp.const(dtypes.index, 6000))+c6)+((c7*UOp.const(dtypes.index, 16))+c8)).alu(Ops.CMPLT, UOp.const(dtypes.index, 59999)).where(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)).reduce(c7, c8, arg=Ops.ADD)+UOp.const(dtypes.int, -1))).where(UOp.const(dtypes.uchar, 0), c10).reduce(c6, arg=Ops.ADD)
c12 = c0.index((((c1*UOp.const(dtypes.index, 7840))+(c2*UOp.const(dtypes.index, 10)))+c3).valid(UOp.const(dtypes.bool, True))).store(c11, c1, c2, c3)
ast = c12.sink(arg=KernelInfo(name='test', axis_types=(), dont_use_locals=False, applied_opts=(Opt(op=OptOps.GROUP, axis=1, arg=16),), opts_to_apply=None))
_ = get_program(ast, Device["METAL"].renderer)

View File

@@ -581,7 +581,7 @@ class TestUOpGraph(unittest.TestCase):
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
gidx0 = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 42),), "gidx0")
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<8),)).cast(dtypes.index)
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<32)),)).cast(dtypes.index)
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<32)),))
to_uops_list([ld1])
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<64)),))

View File

@@ -95,7 +95,7 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([])
# lower the index dtype to a concrete int
ret.append(RewriteStep(load_store_indexing+pm_lower_index_dtype, lambda _: opts.device, name="lower all index dtypes"))
ret.append(RewriteStep(pm_lower_index_dtype+load_store_indexing, lambda _: opts.device, name="lower all index dtypes"))
# optional pre matcher
if opts.pre_matcher is not None: ret.append(RewriteStep(opts.pre_matcher, name="pre_matcher"))

View File

@@ -53,6 +53,8 @@ def delete_redundant_gates(store:UOp, buf:UOp, idx:UOp, val:UOp, store_gate:UOp,
load_store_indexing = PatternMatcher([
# image load valid idx simplification
(UPat(Ops.INDEX, src=(UPat.var("buf"), invalid_gate)), lambda buf,x,i,cond: simplify_valid_load(buf, x, cond)),
# simplify away long after index has been lowered
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x", dtypes.long), UPat.var("c", dtypes.bool))), lambda buf,x,c: simplify_valid_load(buf, x, c)),
# drop true gate
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.const(dtypes.bool, True)),), lambda buf,x: buf.index(x)),
# delete_redundant_gates (after expand)

View File

@@ -8,7 +8,7 @@ from tinygrad.dtype import _from_np_dtype, _to_np_dtype
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray, RANGEIFY, FUSE_ATTENTION
from tinygrad.gradient import compute_gradient
from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata, index_to_concrete_int, sint_to_uop, \
from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, \
srender
from tinygrad.uop.spec import tensor_uop_spec, type_verify
from tinygrad.device import Device, Buffer
@@ -143,7 +143,7 @@ class Tensor(MathTrait):
if isinstance(data, UOp):
assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
# if data is dtype.index that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of
if data.dtype==dtypes.index: data = index_to_concrete_int(data)
if data.dtype==dtypes.index: data = _index_to_concrete_int(data)
if data.op is Ops.BIND: # type: ignore # mypy type narrowing is bugged here
var, val = data.unbind() # type: ignore
# give the bound constant a device

View File

@@ -1073,31 +1073,29 @@ def sint_to_uop(x:sint) -> UOp: return UOp.const(dtypes.index, x) if isinstance(
def select_dtype(u): return (dtypes.long if u.overflows(dtypes.int32) else dtypes.int).vec(u.dtype.count)
pm_lower_index_dtype = PatternMatcher([
# There are no Unary ops at this point in symbolic, those are introduced later
(UPat(GroupOp.Binary, dtypes.index, name="u", src=(UPat.var("x"), UPat.var("y"))), lambda u,x,y:
x.cast(dt:=least_upper_dtype(select_dtype(u), x.dtype, y.dtype)).alu(u.op, y.cast(dt))),
# comparison ops might now have different dtypes in their sources
(UPat(GroupOp.Comparison, name="u", src=(UPat.var("x",dtypes.ints), UPat.var("y", dtypes.ints))), lambda u,x,y:
x.cast(dt:=least_upper_dtype(x.dtype, y.dtype)).alu(u.op, y.cast(dt)) if x.dtype!=y.dtype else None),
(UPat(Ops.WHERE, dtypes.index, src=(UPat(), UPat.var("x"), UPat(Ops.CONST, arg=Invalid)), name="u"), lambda u,x: u.replace(dtype=x.dtype)),
(UPat(Ops.WHERE, dtypes.index, src=(UPat.var("cond"), UPat.var("x"), UPat.var("y"))), lambda cond,x,y:
cond.where(x.cast(dt:=least_upper_dtype(x.dtype, y.dtype)), y.cast(dt))),
(UPat((Ops.CONST, Ops.VCONST), dtype=dtypes.index, name="u"), lambda u: u.replace(dtype=select_dtype(u)) if u.arg != Invalid else None),
(UPat((Ops.RANGE,), dtype=dtypes.index, src=(UPat.var("end")), name="r"), lambda ctx,r,end:
r.replace(dtype=(dt:=select_dtype(r)), src=(end.cast(dt),))),
(UPat(Ops.CAST, dtype=dtypes.index, src=(UPat.var("x", dtypes.ints),), name="u"), lambda u,x: x),
(UPat(Ops.VECTORIZE, dtype=dtypes.index, name="u"), lambda u: u.replace(
dtype=(dt:=least_upper_dtype(*[x.dtype for x in u.src])).vec(u.dtype.count), src=tuple(x.cast(dt) for x in u.src))),
(UPat(Ops.VECTORIZE, dtype=dtypes.index, name="u"), lambda u: u.replace(dtype=(dt:=(dtypes.long if any(v.overflows(dtypes.int) for v in u.src)
else dtypes.long)).vec(u.dtype.count),src=tuple(x.cast(dt) for x in u.src))),
(UPat((Ops.SPECIAL,Ops.DEFINE_VAR), dtypes.index, name="u"), lambda u: u.replace(dtype=dtypes.int)),
(UPat((Ops.BIND), dtypes.index, name="u"), lambda u: u.replace(dtype=u.src[0].dtype)),
(UPat(GroupOp.Binary, name="u", src=(UPat.var("x").cast(dtypes.index), UPat.var("y").cast(dtypes.index))), lambda u,x,y:
x.cast(dt:=least_upper_dtype(select_dtype(u), x.dtype, y.dtype)).alu(u.op, y.cast(dt)).cast(u.dtype)),
(UPat((Ops.CONST, Ops.VCONST), dtype=dtypes.index, name="u"), lambda u: u.replace(dtype=select_dtype(u)).cast(u.dtype) if u.arg!=Invalid else None),
(UPat(Ops.WHERE, dtypes.index, src=(UPat.var("cond"), UPat.var("x").cast(dtypes.index), UPat.var("y").cast(dtypes.index))), lambda cond,x,y:
cond.where(x.cast(dt:=least_upper_dtype(x.dtype, y.dtype)), y.cast(dt)).cast(dtypes.index)),
(UPat(Ops.RANGE, src=(UPat.var("end").cast(dtypes.index)), name="r"), lambda r,end: r.replace(dtype=end.dtype, src=(end,)).cast(dtypes.index)),
(UPat(Ops.VECTORIZE, src=UPat().cast(dtypes.index), name="v"),
lambda v: v.replace(dtype=(dt:=select_dtype(v)), src=tuple(s.src[0].cast(dt.scalar()) for s in v.src)).cast(dtypes.index)),
# special can only be int32
(UPat(Ops.SPECIAL, src=(UPat.var("var").cast(dtypes.index),), name="u"), lambda u,var: u.replace(dtype=dtypes.int, src=(var,)).cast(dtypes.index)),
(UPat(Ops.DEFINE_VAR, dtype=dtypes.index, name="u"), lambda u: u.replace(dtype=dtypes.int).cast(dtypes.index)),
(UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.index), UPat.cvar("val").cast(dtypes.index))), lambda var,val: var.bind(val).cast(dtypes.index)),
(UPat(Ops.CAST, src=(UPat(name="x").cast(dtypes.index),), name="c"), lambda x,c: x.cast(c.dtype)),
# lower Invalid
(UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond)),
# remove hanging cast
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast()),), lambda buf,idx: buf.index(idx)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)),
# remove hanging casts
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)),
(UPat((Ops.STORE, Ops.LOAD), src=(UPat(), UPat(), UPat().cast(dtypes.index)), allow_any_len=True, name="s"),
lambda s: s.replace(src=s.src[:2]+tuple(u.src[0] for u in s.src[2:]))),
(UPat((Ops.SINK, Ops.NOOP), src=UPat().cast(dtypes.index), name="n"), lambda n: n.replace(src=tuple(s.src[0] for s in n.src))),
])
def index_to_concrete_int(u:UOp): return graph_rewrite(u, pm_lower_index_dtype)
def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])