fix long_decomp with None tag (#14707)

fixed `DEBUG=2 WEBGPU=1 python -m pytest test/null/test_tensor.py::TestIdxUpcast::test_int64_unsupported_overflow_sym`
This commit is contained in:
chenyu
2026-02-12 09:31:52 -05:00
committed by GitHub
parent 557134e1c7
commit 212789e31e

View File

@@ -492,7 +492,7 @@ pm_long_decomp = PatternMatcher([
l2i(x.op, x.dtype, a.rtag(0).cast(dt:=l2i_dt[a.dtype]), a.rtag(1).cast(dt)) if x.dtype not in l2i_dt and a.tag is None else None),
(UPat((*(GroupOp.ALU - GroupOp.Comparison), Ops.BITCAST), tuple(l2i_dt.keys()), name="x"), lambda x:
l2i(x.op, l2i_dt[x.dtype], *flatten((a.rtag(0).cast(dt:=l2i_dt[x.src[-1].dtype]), a.rtag(1).cast(dt))
if a.dtype in l2i_dt else (a,) for a in x.src))[x.tag]),
if a.dtype in l2i_dt else (a,) for a in x.src))[x.tag] if x.tag is not None else None),
(UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'),), name='x'), lambda x,idx: x.replace(dtype=l2i_dt[x.dtype],src=(reindex(idx, x.tag),))),
(UPat(Ops.CONST, tuple(l2i_dt.keys()), name='x'), lambda x:
UOp.const(dt:=l2i_dt[x.dtype], truncate[dt]((x.arg >> 32) if x.tag == 1 else (x.arg & 0xFFFFFFFF))))