more generic lt folding (#5665)
@@ -17,7 +17,8 @@ render_ops: Any = { NumNode: lambda self, ops, ctx: UOp.const(dtypes.bigint, sel
   DivNode: lambda self, ops, ctx: self.a.render(ops, ctx)//variable_to_uop(self.b, ctx),
   ModNode: lambda self, ops, ctx: self.a.render(ops, ctx)%variable_to_uop(self.b, ctx),
   LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)),
-  Variable: lambda self,ops,ctx: ctx[self] if ctx is not None and self in ctx else UOp(UOps.DEFINE_VAR, dtypes.int32, (), self),
+  Variable: lambda self,ops,ctx: ctx[self] if ctx is not None and self in ctx else \
+    UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, self.min), UOp.const(dtypes.int, self.max)), self),
   SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a+b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)),
   AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
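A rendered Variable now carries its bounds: instead of a DEFINE_VAR with an empty src, the renderer emits a DEFINE_VAR whose two sources are const UOps holding self.min and self.max (and the dtype moves from int32 to the default int). A minimal standalone sketch of that shape, using a stand-in dataclass rather than tinygrad's real UOp:

# Hedged sketch with a stand-in class, not tinygrad's UOp: the point is
# only that the (min, max) bounds now travel as const sources of DEFINE_VAR.
from dataclasses import dataclass
from typing import Any, Tuple

@dataclass(frozen=True)
class FakeUOp:
  op: str
  src: Tuple["FakeUOp", ...] = ()
  arg: Any = None

def render_variable(name: str, vmin: int, vmax: int) -> FakeUOp:
  # mirrors UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, min), UOp.const(dtypes.int, max)), var)
  return FakeUOp("DEFINE_VAR", (FakeUOp("CONST", arg=vmin), FakeUOp("CONST", arg=vmax)), name)

v = render_variable("i", 0, 9)
assert (v.src[0].arg, v.src[1].arg) == (0, 9)  # bounds are recoverable from src alone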
@@ -213,7 +213,7 @@ constant_folder = PatternMatcher([
   (UOp(UOps.GEP, src=(UOp.cvar("x"),)).name("root"), lambda root,x: root.const(x.arg)),
   # max -2147483648
   (UOp.max(UOp.var('x'), UOp.const(dtypes.int, -2147483648)), lambda x: x),
-  # bool < False is always false, True < bool is always false
+  # bool < False is always false, True < bool is always false # TODO: replace these with generic cmp
   (UOp.var().lt(UOp.const(dtypes.bool, False)), lambda: UOp.const(dtypes.bool, False)),
   (UOp.const(dtypes.bool, True).lt(UOp.var()), lambda: UOp.const(dtypes.bool, False)),
   # a conditional with the same results either way is a noop, also fold const conditionals
@@ -239,8 +239,8 @@ constant_folder = PatternMatcher([
   (UOp.var('x') * 0, lambda x: x.const(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
   (UOp.var('x') - UOp.var('x'), lambda x: x.const(0)), # x-x -> 0
   # lt folding
-  (UOp.var('x').lt(UOp.cvar('c')),
-   lambda x,c: UOp.const(dtypes.bool, True) if x.vmax.arg < c.arg else UOp.const(dtypes.bool, False) if x.vmin.arg >= c.arg else None),
+  (UOp.var('x').lt(UOp.var('y')),
+   lambda x,y: UOp.const(dtypes.bool, True) if x.vmax.arg < y.vmin.arg else UOp.const(dtypes.bool, False) if x.vmin.arg >= y.vmax.arg else None),
   # ** load/store folding **
   (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.load(UOp.var("buf"), UOp.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)),
   # ** two stage add/sub folding **
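The new x < y rule subsumes the old x < const case: a CONST's vmin and vmax are the constant itself, so comparing against a constant is just the degenerate range-vs-range comparison. A hedged sketch of the decision procedure on plain ints (in tinygrad the bounds are const UOps whose value lives in .arg):

# Hedged sketch of the lt-folding decision, not tinygrad's actual rewrite rule.
from typing import Optional

def fold_lt(x_min: float, x_max: float, y_min: float, y_max: float) -> Optional[bool]:
  if x_max < y_min: return True    # every x is below every y
  if x_min >= y_max: return False  # no x can be below any y
  return None                      # ranges overlap: leave the compare in the graph

assert fold_lt(0, 9, 10, 10) is True   # x in [0,9] < const 10 always holds
assert fold_lt(10, 20, 0, 5) is False  # x in [10,20] < y in [0,5] never holds
assert fold_lt(0, 9, 5, 5) is None     # undecidable from bounds alone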
@@ -73,8 +73,8 @@ class UOp:
   @staticmethod
   @functools.lru_cache(maxsize=None)
   def _const(dtype:Optional[DType], b:ConstType|Variable):
-    # TODO: min/max for const Variable?
-    if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (), b)
+    # TODO: fix dtype of b.max after Variable is just an UOp
+    if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (UOp.const(dtypes.int, b.min), UOp.const(dtypes.int, cast(int,b.max))), b)
     return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
   @staticmethod
   def alu(arg, *src:UOp): return UOp(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else src[-1].dtype, src, arg)
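Side note on _const: because it sits under functools.lru_cache, requesting the same constant twice returns the same cached object, and the cast(int, b.max) works around b.max's looser type until Variable itself becomes a UOp, per the updated TODO. A tiny sketch of the caching property, with a stand-in constructor:

# Hedged sketch: make_const is a stand-in, not tinygrad's UOp.const, but the
# lru_cache behavior is the same: equal arguments yield the identical object.
import functools

@functools.lru_cache(maxsize=None)
def make_const(value: int) -> tuple:
  return ("CONST", value)

assert make_const(0) is make_const(0)  # shared, so repeated bound consts dedup for free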
@@ -101,10 +101,12 @@ class UOp:
   @functools.cached_property
   def vmax(self) -> UOp:
+    if self.op is UOps.DEFINE_VAR: return self.src[1]
     if self.op is UOps.CONST: return self
     return UOp.const(dtypes.float, math.inf)
   @functools.cached_property
   def vmin(self) -> UOp:
+    if self.op is UOps.DEFINE_VAR: return self.src[0]
     if self.op is UOps.CONST: return self
     return UOp.const(dtypes.float, -math.inf)
 
 class UPat:
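With the two added lines, vmin/vmax resolve in three tiers: a DEFINE_VAR answers from its new bound sources (src[0] is the min, src[1] the max), a CONST is its own bound, and everything else falls back to an unbounded float infinity. A standalone sketch of that lookup order, again with stand-in tuples rather than real UOps:

# Hedged sketch of the three-tier lookup; ("CONST", v) stands in for a const UOp.
import math

def vmax(op: str, src: tuple, arg=None):
  if op == "DEFINE_VAR": return src[1]    # upper bound travels as the second source
  if op == "CONST": return ("CONST", arg)  # a constant is its own maximum
  return ("CONST", math.inf)               # unknown ops stay unbounded

def vmin(op: str, src: tuple, arg=None):
  if op == "DEFINE_VAR": return src[0]    # lower bound is the first source
  if op == "CONST": return ("CONST", arg)
  return ("CONST", -math.inf)

lo, hi = ("CONST", 0), ("CONST", 9)
assert vmin("DEFINE_VAR", (lo, hi)) == lo and vmax("DEFINE_VAR", (lo, hi)) == hi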