diff --git a/tinygrad/uop/decompositions.py b/tinygrad/uop/decompositions.py index 1689f45ab8..683135f7bb 100644 --- a/tinygrad/uop/decompositions.py +++ b/tinygrad/uop/decompositions.py @@ -492,7 +492,7 @@ pm_long_decomp = PatternMatcher([ l2i(x.op, x.dtype, a.rtag(0).cast(dt:=l2i_dt[a.dtype]), a.rtag(1).cast(dt)) if x.dtype not in l2i_dt and a.tag is None else None), (UPat((*(GroupOp.ALU - GroupOp.Comparison), Ops.BITCAST), tuple(l2i_dt.keys()), name="x"), lambda x: l2i(x.op, l2i_dt[x.dtype], *flatten((a.rtag(0).cast(dt:=l2i_dt[x.src[-1].dtype]), a.rtag(1).cast(dt)) - if a.dtype in l2i_dt else (a,) for a in x.src))[x.tag]), + if a.dtype in l2i_dt else (a,) for a in x.src))[x.tag] if x.tag is not None else None), (UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'),), name='x'), lambda x,idx: x.replace(dtype=l2i_dt[x.dtype],src=(reindex(idx, x.tag),))), (UPat(Ops.CONST, tuple(l2i_dt.keys()), name='x'), lambda x: UOp.const(dt:=l2i_dt[x.dtype], truncate[dt]((x.arg >> 32) if x.tag == 1 else (x.arg & 0xFFFFFFFF))))