assign fix clear

This commit is contained in:
George Hotz
2026-01-11 16:24:48 +09:00
parent 44752804f2
commit 63b41155a7

View File

@@ -335,13 +335,10 @@ pcode_pm = PatternMatcher([
(UPat(Ops.DEFINE_VAR, name='u'), _track_var),
(UPat(Ops.DEFINE_VAR, dtype=dtypes.void, name='u'), _prop_var),
(UPat(Ops.ASSIGN, src=(UPat(Ops.DEFINE_VAR, dtype=dtypes.void, name='lhs'), UPat.var('rhs'))), _prop_assign),
# Propagate dtype for ASSIGN from rhs, or infer rhs dtype from lhs if rhs is void
# Propagate dtype for ASSIGN from rhs, or propagate lhs type to void DEFINE_VAR rhs
(UPat(Ops.ASSIGN, dtype=dtypes.void, src=(UPat.var('lhs'), UPat.var('rhs'))),
lambda lhs, rhs: UOp(Ops.ASSIGN, rhs.dtype, (lhs, rhs)) if rhs.dtype != dtypes.void else
UOp(Ops.ASSIGN, lhs.dtype, (lhs, rhs.replace(dtype=lhs.dtype))) if lhs.dtype != dtypes.void else None),
# Fix ASSIGN type mismatch: ensure dtype matches rhs, cast rhs if needed to match lhs
(UPat(Ops.ASSIGN, src=(UPat.var('lhs'), UPat.var('rhs')), name='a'),
lambda a, lhs, rhs: UOp(Ops.ASSIGN, lhs.dtype, (lhs, UOp(Ops.CAST, lhs.dtype, (rhs,)))) if lhs.dtype != dtypes.void and rhs.dtype != dtypes.void and a.dtype != rhs.dtype else None),
UOp(Ops.ASSIGN, lhs.dtype, (lhs, rhs.replace(dtype=lhs.dtype))) if lhs.dtype != dtypes.void and rhs.op == Ops.DEFINE_VAR else None),
# Dtype propagation for void-typed ops
(UPat((Ops.ADD, Ops.SUB, Ops.MUL, Ops.FDIV, Ops.AND, Ops.OR, Ops.XOR, Ops.SHL, Ops.SHR, Ops.MOD, Ops.POW),
dtype=dtypes.void, src=(UPat.var('l'), UPat.var('r')), name='__OP__'), _prop_binop),