mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix long_decomp with None tag (#14707)
fixed `DEBUG=2 WEBGPU=1 python -m pytest test/null/test_tensor.py::TestIdxUpcast::test_int64_unsupported_overflow_sym`
This commit is contained in:
@@ -492,7 +492,7 @@ pm_long_decomp = PatternMatcher([
|
||||
l2i(x.op, x.dtype, a.rtag(0).cast(dt:=l2i_dt[a.dtype]), a.rtag(1).cast(dt)) if x.dtype not in l2i_dt and a.tag is None else None),
|
||||
(UPat((*(GroupOp.ALU - GroupOp.Comparison), Ops.BITCAST), tuple(l2i_dt.keys()), name="x"), lambda x:
|
||||
l2i(x.op, l2i_dt[x.dtype], *flatten((a.rtag(0).cast(dt:=l2i_dt[x.src[-1].dtype]), a.rtag(1).cast(dt))
|
||||
if a.dtype in l2i_dt else (a,) for a in x.src))[x.tag]),
|
||||
if a.dtype in l2i_dt else (a,) for a in x.src))[x.tag] if x.tag is not None else None),
|
||||
(UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'),), name='x'), lambda x,idx: x.replace(dtype=l2i_dt[x.dtype],src=(reindex(idx, x.tag),))),
|
||||
(UPat(Ops.CONST, tuple(l2i_dt.keys()), name='x'), lambda x:
|
||||
UOp.const(dt:=l2i_dt[x.dtype], truncate[dt]((x.arg >> 32) if x.tag == 1 else (x.arg & 0xFFFFFFFF))))
|
||||
|
||||
Reference in New Issue
Block a user