diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 790254b13c..68480e510b 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -788,8 +788,10 @@ spec = PatternMatcher([ (UPat(UOps.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPLT), lambda x,y: x.dtype == y.dtype), (UPat(UOps.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPNE), lambda x,y: x.dtype == y.dtype), # and SHL/SHR, the shift distance is an int - (UPat(UOps.ALU, src=(UPat(name="x"), UPat()), name="alu", arg=BinaryOps.SHL), lambda alu,x: alu.dtype == x.dtype), - (UPat(UOps.ALU, src=(UPat(name="x"), UPat()), name="alu", arg=BinaryOps.SHR), lambda alu,x: alu.dtype == x.dtype), + (UPat(UOps.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHL), + lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)), + (UPat(UOps.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHR), + lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)), (UPat(UOps.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False), (UPat(UOps.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),