more generic lt folding (#5665)

This commit is contained in:
chenyu
2024-07-23 19:50:59 -04:00
committed by GitHub
parent 7c8fe0fe47
commit e196640d71
3 changed files with 9 additions and 6 deletions

View File

@@ -17,7 +17,8 @@ render_ops: Any = { NumNode: lambda self, ops, ctx: UOp.const(dtypes.bigint, sel
DivNode: lambda self, ops, ctx: self.a.render(ops, ctx)//variable_to_uop(self.b, ctx),
ModNode: lambda self, ops, ctx: self.a.render(ops, ctx)%variable_to_uop(self.b, ctx),
LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)),
Variable: lambda self,ops,ctx: ctx[self] if ctx is not None and self in ctx else UOp(UOps.DEFINE_VAR, dtypes.int32, (), self),
Variable: lambda self,ops,ctx: ctx[self] if ctx is not None and self in ctx else \
UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, self.min), UOp.const(dtypes.int, self.max)), self),
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a+b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)),
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) }

View File

@@ -213,7 +213,7 @@ constant_folder = PatternMatcher([
(UOp(UOps.GEP, src=(UOp.cvar("x"),)).name("root"), lambda root,x: root.const(x.arg)),
# max -2147483648
(UOp.max(UOp.var('x'), UOp.const(dtypes.int, -2147483648)), lambda x: x),
# bool < False is always false, True < bool is always false
# bool < False is always false, True < bool is always false # TODO: replace these with generic cmp
(UOp.var().lt(UOp.const(dtypes.bool, False)), lambda: UOp.const(dtypes.bool, False)),
(UOp.const(dtypes.bool, True).lt(UOp.var()), lambda: UOp.const(dtypes.bool, False)),
# a conditional with the same results either way is a noop, also fold const conditionals
@@ -239,8 +239,8 @@ constant_folder = PatternMatcher([
(UOp.var('x') * 0, lambda x: x.const(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
(UOp.var('x') - UOp.var('x'), lambda x: x.const(0)), # x-x -> 0
# lt folding
(UOp.var('x').lt(UOp.cvar('c')),
lambda x,c: UOp.const(dtypes.bool, True) if x.vmax.arg < c.arg else UOp.const(dtypes.bool, False) if x.vmin.arg >= c.arg else None),
(UOp.var('x').lt(UOp.var('y')),
lambda x,y: UOp.const(dtypes.bool, True) if x.vmax.arg < y.vmin.arg else UOp.const(dtypes.bool, False) if x.vmin.arg >= y.vmax.arg else None),
# ** load/store folding **
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.load(UOp.var("buf"), UOp.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)),
# ** two stage add/sub folding **

View File

@@ -73,8 +73,8 @@ class UOp:
@staticmethod
@functools.lru_cache(maxsize=None)
def _const(dtype:Optional[DType], b:ConstType|Variable):
# TODO: min/max for const Variable?
if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (), b)
# TODO: fix dtype of b.max after Variable is just an UOp
if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (UOp.const(dtypes.int, b.min), UOp.const(dtypes.int, cast(int,b.max))), b)
return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
@staticmethod
def alu(arg, *src:UOp): return UOp(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else src[-1].dtype, src, arg)
@@ -101,10 +101,12 @@ class UOp:
@functools.cached_property
def vmax(self) -> UOp:
if self.op is UOps.DEFINE_VAR: return self.src[1]
if self.op is UOps.CONST: return self
return UOp.const(dtypes.float, math.inf)
@functools.cached_property
def vmin(self) -> UOp:
if self.op is UOps.DEFINE_VAR: return self.src[0]
if self.op is UOps.CONST: return self
return UOp.const(dtypes.float, -math.inf)
class UPat: