update shifts spec (#7468)

* update shifts spec

* hotfix: old style
This commit is contained in:
ignaciosica
2024-11-01 13:40:41 -03:00
committed by GitHub
parent 18bd98c203
commit 9c832483f2

View File

@@ -788,8 +788,10 @@ spec = PatternMatcher([
(UPat(UOps.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPLT), lambda x,y: x.dtype == y.dtype),
(UPat(UOps.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPNE), lambda x,y: x.dtype == y.dtype),
# and SHL/SHR, the shift distance is an int
(UPat(UOps.ALU, src=(UPat(name="x"), UPat()), name="alu", arg=BinaryOps.SHL), lambda alu,x: alu.dtype == x.dtype),
(UPat(UOps.ALU, src=(UPat(name="x"), UPat()), name="alu", arg=BinaryOps.SHR), lambda alu,x: alu.dtype == x.dtype),
(UPat(UOps.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHL),
lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)),
(UPat(UOps.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHR),
lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)),
(UPat(UOps.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
(UPat(UOps.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),