From 9c832483f275c41b75622676ad0bf64f8b827d69 Mon Sep 17 00:00:00 2001 From: ignaciosica Date: Fri, 1 Nov 2024 13:40:41 -0300 Subject: [PATCH] update shifts spec (#7468) * update shifts spec * hotfix: old style --- tinygrad/ops.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 790254b13c..68480e510b 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -788,8 +788,10 @@ spec = PatternMatcher([ (UPat(UOps.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPLT), lambda x,y: x.dtype == y.dtype), (UPat(UOps.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPNE), lambda x,y: x.dtype == y.dtype), # and SHL/SHR, the shift distance is an int - (UPat(UOps.ALU, src=(UPat(name="x"), UPat()), name="alu", arg=BinaryOps.SHL), lambda alu,x: alu.dtype == x.dtype), - (UPat(UOps.ALU, src=(UPat(name="x"), UPat()), name="alu", arg=BinaryOps.SHR), lambda alu,x: alu.dtype == x.dtype), + (UPat(UOps.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHL), + lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)), + (UPat(UOps.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHR), + lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)), (UPat(UOps.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False), (UPat(UOps.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),