mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 13:28:06 -05:00
add UOp bound for BinaryOps.CMPLT (#5833)
and remove the redundant lt folding rule
This commit is contained in:
@@ -210,9 +210,6 @@ constant_folder = PatternMatcher([
|
||||
(NOp.var('x') * 0, lambda x: x.const(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
|
||||
(NOp.var('x') - NOp.var('x'), lambda x: x.const(0)), # x-x -> 0
|
||||
(UPat(op=UOps.ALU, name='x'), lambda x: x.const(x.vmin.arg) if x.op is not UOps.CONST and x.vmin.arg == x.vmax.arg else None),
|
||||
# lt folding
|
||||
(NOp.var('x').lt(NOp.var('y')),
|
||||
lambda x,y: NOp.const(dtypes.bool, True) if x.vmax.arg < y.vmin.arg else NOp.const(dtypes.bool, False) if x.vmin.arg >= y.vmax.arg else None),
|
||||
# c0*x<c1 for positive int c0,c1
|
||||
((NOp.cvar('c0',dtypes.int)*NOp.var('x')).lt(NOp.cvar('c1',dtypes.int)),
|
||||
lambda x,c0,c1: x.lt(math.ceil(c1.arg/c0.arg)) if c0.arg > 0 and c1.arg > 0 else None),
|
||||
|
||||
@@ -119,6 +119,8 @@ class UOp:
|
||||
if self.arg is BinaryOps.MOD and s1.op is UOps.CONST and s1.arg > 0: return self.const(0), self.const(s1.arg-1)
|
||||
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST and s1.arg > 0: return self.const(s0.vmin.arg//s1.arg), self.const(s0.vmax.arg//s1.arg)
|
||||
if self.arg is BinaryOps.MAX: return self.const(max(s0.vmin.arg, s1.vmin.arg)), self.const(max(s0.vmax.arg, s1.vmax.arg))
|
||||
if self.arg is BinaryOps.CMPLT: return (UOp.const(dtypes.bool, True), UOp.const(dtypes.bool, True)) if s0.vmax.arg < s1.vmin.arg else \
|
||||
(UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, False)) if s0.vmin.arg >= s1.vmax.arg else (None, None)
|
||||
return None, None
|
||||
|
||||
@dataclass(frozen=True, repr=False) # reuse repr from UOp
|
||||
|
||||
Reference in New Issue
Block a user