mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
remove a few rules in pm_lower_index_dtype [pr] (#13317)
This commit is contained in:
@@ -1265,15 +1265,12 @@ pm_lower_index_dtype = PatternMatcher([
|
||||
(UPat(Ops.SPECIAL, src=(UPat.var("var").cast(dtypes.index),), name="u"), lambda u,var: u.replace(dtype=dtypes.int, src=(var,)).cast(dtypes.index)),
|
||||
(UPat(Ops.DEFINE_VAR, dtype=dtypes.index, name="u"), lambda u: u.replace(dtype=dtypes.int).cast(dtypes.index)),
|
||||
(UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.index), UPat.cvar("val").cast(dtypes.index))), lambda var,val: var.bind(val).cast(dtypes.index)),
|
||||
(UPat(Ops.CAST, src=(UPat(name="x").cast(dtypes.index),), name="c"), lambda x,c: x.cast(c.dtype)),
|
||||
# lower Invalid
|
||||
(UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond, ptr=True)),
|
||||
# remove hanging casts
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))),
|
||||
lambda buf,idx,valid: buf.index(idx, valid, ptr=True)),
|
||||
(UPat((Ops.STORE, Ops.LOAD), src=(UPat(), UPat(), UPat().cast(dtypes.index)), allow_any_len=True, name="s"),
|
||||
lambda s: s.replace(src=s.src[:2]+tuple(u.src[0] for u in s.src[2:]))),
|
||||
(UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"),
|
||||
lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.index else s for s in n.src))),
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user