mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
Lower index dtype spec fix (#12337)
* new pm_lower_index_dtype * load_store_indexing after index lowering * shorten line * seperate rule for long removal * fix test * fix index_to_concrete_int * minor fixes * add sink there * update types in linearizer test
This commit is contained in:
@@ -15,18 +15,18 @@ class TestLinearizerFailure(unittest.TestCase):
|
||||
@unittest.skipUnless(Device.DEFAULT == "METAL", "only tested on METAL")
|
||||
def test_failure_beam_mnist(self):
|
||||
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(4014080), arg=0, src=())
|
||||
c1 = UOp.range(UOp.const(dtypes.int, 512), 0, AxisType.GLOBAL)
|
||||
c2 = UOp.range(UOp.const(dtypes.int, 784), 1, AxisType.GLOBAL)
|
||||
c3 = UOp.range(UOp.const(dtypes.int, 10), 3, AxisType.GLOBAL)
|
||||
c1 = UOp.range(UOp.const(dtypes.index, 512), 0, AxisType.GLOBAL)
|
||||
c2 = UOp.range(UOp.const(dtypes.index, 784), 1, AxisType.GLOBAL)
|
||||
c3 = UOp.range(UOp.const(dtypes.index, 10), 3, AxisType.GLOBAL)
|
||||
c4 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(512), arg=1, src=())
|
||||
c5 = c4.index(c1, UOp.const(dtypes.bool, True)).load()
|
||||
c6 = UOp.range(UOp.const(dtypes.int, 6000), 1004, AxisType.REDUCE)
|
||||
c7 = UOp.range(UOp.const(dtypes.int, 3750), 2006, AxisType.REDUCE)
|
||||
c8 = UOp.range(UOp.const(dtypes.int, 16), 2007, AxisType.GROUP_REDUCE)
|
||||
c5 = c4.index(c1.valid(UOp.const(dtypes.bool, True))).load()
|
||||
c6 = UOp.range(UOp.const(dtypes.index, 6000), 1004, AxisType.REDUCE)
|
||||
c7 = UOp.range(UOp.const(dtypes.index, 3750), 2006, AxisType.REDUCE)
|
||||
c8 = UOp.range(UOp.const(dtypes.index, 16), 2007, AxisType.GROUP_REDUCE)
|
||||
c9 = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(47040000), arg=2, src=())
|
||||
c10 = c9.index((((c3*UOp.const(dtypes.int, 4704000))+c2)+(c6*UOp.const(dtypes.int, 784))), UOp.const(dtypes.bool, True)).load()
|
||||
c11 = c5.alu(Ops.CMPNE, ((((c3*UOp.const(dtypes.int, 6000))+c6)+((c7*UOp.const(dtypes.int, 16))+c8)).alu(Ops.CMPLT, UOp.const(dtypes.int, 59999)).where(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)).reduce(c7, c8, arg=Ops.ADD)+UOp.const(dtypes.int, -1))).where(UOp.const(dtypes.uchar, 0), c10).reduce(c6, arg=Ops.ADD)
|
||||
c12 = c0.index((((c1*UOp.const(dtypes.int, 7840))+(c2*UOp.const(dtypes.int, 10)))+c3), UOp.const(dtypes.bool, True)).store(c11, c1, c2, c3)
|
||||
c10 = c9.index((((c3*UOp.const(dtypes.index, 4704000))+c2)+(c6*UOp.const(dtypes.index, 784))).valid(UOp.const(dtypes.bool, True))).load()
|
||||
c11 = c5.alu(Ops.CMPNE, ((((c3*UOp.const(dtypes.index, 6000))+c6)+((c7*UOp.const(dtypes.index, 16))+c8)).alu(Ops.CMPLT, UOp.const(dtypes.index, 59999)).where(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)).reduce(c7, c8, arg=Ops.ADD)+UOp.const(dtypes.int, -1))).where(UOp.const(dtypes.uchar, 0), c10).reduce(c6, arg=Ops.ADD)
|
||||
c12 = c0.index((((c1*UOp.const(dtypes.index, 7840))+(c2*UOp.const(dtypes.index, 10)))+c3).valid(UOp.const(dtypes.bool, True))).store(c11, c1, c2, c3)
|
||||
ast = c12.sink(arg=KernelInfo(name='test', axis_types=(), dont_use_locals=False, applied_opts=(Opt(op=OptOps.GROUP, axis=1, arg=16),), opts_to_apply=None))
|
||||
_ = get_program(ast, Device["METAL"].renderer)
|
||||
|
||||
|
||||
@@ -581,7 +581,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 42),), "gidx0")
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<8),)).cast(dtypes.index)
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<32)),)).cast(dtypes.index)
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<32)),))
|
||||
to_uops_list([ld1])
|
||||
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<64)),))
|
||||
|
||||
@@ -95,7 +95,7 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
|
||||
extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([])
|
||||
|
||||
# lower the index dtype to a concrete int
|
||||
ret.append(RewriteStep(load_store_indexing+pm_lower_index_dtype, lambda _: opts.device, name="lower all index dtypes"))
|
||||
ret.append(RewriteStep(pm_lower_index_dtype+load_store_indexing, lambda _: opts.device, name="lower all index dtypes"))
|
||||
|
||||
# optional pre matcher
|
||||
if opts.pre_matcher is not None: ret.append(RewriteStep(opts.pre_matcher, name="pre_matcher"))
|
||||
|
||||
@@ -53,6 +53,8 @@ def delete_redundant_gates(store:UOp, buf:UOp, idx:UOp, val:UOp, store_gate:UOp,
|
||||
load_store_indexing = PatternMatcher([
|
||||
# image load valid idx simplification
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), invalid_gate)), lambda buf,x,i,cond: simplify_valid_load(buf, x, cond)),
|
||||
# simplify away long after index has been lowered
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x", dtypes.long), UPat.var("c", dtypes.bool))), lambda buf,x,c: simplify_valid_load(buf, x, c)),
|
||||
# drop true gate
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.const(dtypes.bool, True)),), lambda buf,x: buf.index(x)),
|
||||
# delete_redundant_gates (after expand)
|
||||
|
||||
@@ -8,7 +8,7 @@ from tinygrad.dtype import _from_np_dtype, _to_np_dtype
|
||||
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
|
||||
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray, RANGEIFY, FUSE_ATTENTION
|
||||
from tinygrad.gradient import compute_gradient
|
||||
from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata, index_to_concrete_int, sint_to_uop, \
|
||||
from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, \
|
||||
srender
|
||||
from tinygrad.uop.spec import tensor_uop_spec, type_verify
|
||||
from tinygrad.device import Device, Buffer
|
||||
@@ -143,7 +143,7 @@ class Tensor(MathTrait):
|
||||
if isinstance(data, UOp):
|
||||
assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
|
||||
# if data is dtype.index that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of
|
||||
if data.dtype==dtypes.index: data = index_to_concrete_int(data)
|
||||
if data.dtype==dtypes.index: data = _index_to_concrete_int(data)
|
||||
if data.op is Ops.BIND: # type: ignore # mypy type narrowing is bugged here
|
||||
var, val = data.unbind() # type: ignore
|
||||
# give the bound constant a device
|
||||
|
||||
@@ -1073,31 +1073,29 @@ def sint_to_uop(x:sint) -> UOp: return UOp.const(dtypes.index, x) if isinstance(
|
||||
def select_dtype(u): return (dtypes.long if u.overflows(dtypes.int32) else dtypes.int).vec(u.dtype.count)
|
||||
pm_lower_index_dtype = PatternMatcher([
|
||||
# There are no Unary ops at this point in symbolic, those are introduced later
|
||||
(UPat(GroupOp.Binary, dtypes.index, name="u", src=(UPat.var("x"), UPat.var("y"))), lambda u,x,y:
|
||||
x.cast(dt:=least_upper_dtype(select_dtype(u), x.dtype, y.dtype)).alu(u.op, y.cast(dt))),
|
||||
# comparison ops might now have different dtypes in their sources
|
||||
(UPat(GroupOp.Comparison, name="u", src=(UPat.var("x",dtypes.ints), UPat.var("y", dtypes.ints))), lambda u,x,y:
|
||||
x.cast(dt:=least_upper_dtype(x.dtype, y.dtype)).alu(u.op, y.cast(dt)) if x.dtype!=y.dtype else None),
|
||||
(UPat(Ops.WHERE, dtypes.index, src=(UPat(), UPat.var("x"), UPat(Ops.CONST, arg=Invalid)), name="u"), lambda u,x: u.replace(dtype=x.dtype)),
|
||||
(UPat(Ops.WHERE, dtypes.index, src=(UPat.var("cond"), UPat.var("x"), UPat.var("y"))), lambda cond,x,y:
|
||||
cond.where(x.cast(dt:=least_upper_dtype(x.dtype, y.dtype)), y.cast(dt))),
|
||||
(UPat((Ops.CONST, Ops.VCONST), dtype=dtypes.index, name="u"), lambda u: u.replace(dtype=select_dtype(u)) if u.arg != Invalid else None),
|
||||
(UPat((Ops.RANGE,), dtype=dtypes.index, src=(UPat.var("end")), name="r"), lambda ctx,r,end:
|
||||
r.replace(dtype=(dt:=select_dtype(r)), src=(end.cast(dt),))),
|
||||
(UPat(Ops.CAST, dtype=dtypes.index, src=(UPat.var("x", dtypes.ints),), name="u"), lambda u,x: x),
|
||||
(UPat(Ops.VECTORIZE, dtype=dtypes.index, name="u"), lambda u: u.replace(
|
||||
dtype=(dt:=least_upper_dtype(*[x.dtype for x in u.src])).vec(u.dtype.count), src=tuple(x.cast(dt) for x in u.src))),
|
||||
(UPat(Ops.VECTORIZE, dtype=dtypes.index, name="u"), lambda u: u.replace(dtype=(dt:=(dtypes.long if any(v.overflows(dtypes.int) for v in u.src)
|
||||
else dtypes.long)).vec(u.dtype.count),src=tuple(x.cast(dt) for x in u.src))),
|
||||
(UPat((Ops.SPECIAL,Ops.DEFINE_VAR), dtypes.index, name="u"), lambda u: u.replace(dtype=dtypes.int)),
|
||||
(UPat((Ops.BIND), dtypes.index, name="u"), lambda u: u.replace(dtype=u.src[0].dtype)),
|
||||
(UPat(GroupOp.Binary, name="u", src=(UPat.var("x").cast(dtypes.index), UPat.var("y").cast(dtypes.index))), lambda u,x,y:
|
||||
x.cast(dt:=least_upper_dtype(select_dtype(u), x.dtype, y.dtype)).alu(u.op, y.cast(dt)).cast(u.dtype)),
|
||||
(UPat((Ops.CONST, Ops.VCONST), dtype=dtypes.index, name="u"), lambda u: u.replace(dtype=select_dtype(u)).cast(u.dtype) if u.arg!=Invalid else None),
|
||||
(UPat(Ops.WHERE, dtypes.index, src=(UPat.var("cond"), UPat.var("x").cast(dtypes.index), UPat.var("y").cast(dtypes.index))), lambda cond,x,y:
|
||||
cond.where(x.cast(dt:=least_upper_dtype(x.dtype, y.dtype)), y.cast(dt)).cast(dtypes.index)),
|
||||
(UPat(Ops.RANGE, src=(UPat.var("end").cast(dtypes.index)), name="r"), lambda r,end: r.replace(dtype=end.dtype, src=(end,)).cast(dtypes.index)),
|
||||
(UPat(Ops.VECTORIZE, src=UPat().cast(dtypes.index), name="v"),
|
||||
lambda v: v.replace(dtype=(dt:=select_dtype(v)), src=tuple(s.src[0].cast(dt.scalar()) for s in v.src)).cast(dtypes.index)),
|
||||
# special can only be int32
|
||||
(UPat(Ops.SPECIAL, src=(UPat.var("var").cast(dtypes.index),), name="u"), lambda u,var: u.replace(dtype=dtypes.int, src=(var,)).cast(dtypes.index)),
|
||||
(UPat(Ops.DEFINE_VAR, dtype=dtypes.index, name="u"), lambda u: u.replace(dtype=dtypes.int).cast(dtypes.index)),
|
||||
(UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.index), UPat.cvar("val").cast(dtypes.index))), lambda var,val: var.bind(val).cast(dtypes.index)),
|
||||
(UPat(Ops.CAST, src=(UPat(name="x").cast(dtypes.index),), name="c"), lambda x,c: x.cast(c.dtype)),
|
||||
# lower Invalid
|
||||
(UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond)),
|
||||
# remove hanging cast
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast()),), lambda buf,idx: buf.index(idx)),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)),
|
||||
# remove hanging casts
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx)),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)),
|
||||
(UPat((Ops.STORE, Ops.LOAD), src=(UPat(), UPat(), UPat().cast(dtypes.index)), allow_any_len=True, name="s"),
|
||||
lambda s: s.replace(src=s.src[:2]+tuple(u.src[0] for u in s.src[2:]))),
|
||||
(UPat((Ops.SINK, Ops.NOOP), src=UPat().cast(dtypes.index), name="n"), lambda n: n.replace(src=tuple(s.src[0] for s in n.src))),
|
||||
])
|
||||
def index_to_concrete_int(u:UOp): return graph_rewrite(u, pm_lower_index_dtype)
|
||||
def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]
|
||||
|
||||
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user