diff --git a/extra/assembly/amd/pcode_transform.py b/extra/assembly/amd/pcode_transform.py index b95e9d372d..cbddcab675 100644 --- a/extra/assembly/amd/pcode_transform.py +++ b/extra/assembly/amd/pcode_transform.py @@ -335,13 +335,10 @@ pcode_pm = PatternMatcher([ (UPat(Ops.DEFINE_VAR, name='u'), _track_var), (UPat(Ops.DEFINE_VAR, dtype=dtypes.void, name='u'), _prop_var), (UPat(Ops.ASSIGN, src=(UPat(Ops.DEFINE_VAR, dtype=dtypes.void, name='lhs'), UPat.var('rhs'))), _prop_assign), - # Propagate dtype for ASSIGN from rhs, or infer rhs dtype from lhs if rhs is void + # Propagate dtype for ASSIGN from rhs, or propagate lhs type to void DEFINE_VAR rhs (UPat(Ops.ASSIGN, dtype=dtypes.void, src=(UPat.var('lhs'), UPat.var('rhs'))), lambda lhs, rhs: UOp(Ops.ASSIGN, rhs.dtype, (lhs, rhs)) if rhs.dtype != dtypes.void else - UOp(Ops.ASSIGN, lhs.dtype, (lhs, rhs.replace(dtype=lhs.dtype))) if lhs.dtype != dtypes.void else None), - # Fix ASSIGN type mismatch: ensure dtype matches rhs, cast rhs if needed to match lhs - (UPat(Ops.ASSIGN, src=(UPat.var('lhs'), UPat.var('rhs')), name='a'), - lambda a, lhs, rhs: UOp(Ops.ASSIGN, lhs.dtype, (lhs, UOp(Ops.CAST, lhs.dtype, (rhs,)))) if lhs.dtype != dtypes.void and rhs.dtype != dtypes.void and a.dtype != rhs.dtype else None), + UOp(Ops.ASSIGN, lhs.dtype, (lhs, rhs.replace(dtype=lhs.dtype))) if lhs.dtype != dtypes.void and rhs.op == Ops.DEFINE_VAR else None), # Dtype propagation for void-typed ops (UPat((Ops.ADD, Ops.SUB, Ops.MUL, Ops.FDIV, Ops.AND, Ops.OR, Ops.XOR, Ops.SHL, Ops.SHR, Ops.MOD, Ops.POW), dtype=dtypes.void, src=(UPat.var('l'), UPat.var('r')), name='__OP__'), _prop_binop),