remove UOps lt pattern of booleans (#5666)

covered by the generic lt fold pattern
This commit is contained in:
chenyu
2024-07-23 20:11:21 -04:00
committed by GitHub
parent e196640d71
commit ea99efe815
2 changed files with 3 additions and 6 deletions

View File

@@ -213,9 +213,6 @@ constant_folder = PatternMatcher([
(UOp(UOps.GEP, src=(UOp.cvar("x"),)).name("root"), lambda root,x: root.const(x.arg)),
# max -2147483648
(UOp.max(UOp.var('x'), UOp.const(dtypes.int, -2147483648)), lambda x: x),
# bool < False is always false, True < bool is always false # TODO: replace these with generic cmp
(UOp.var().lt(UOp.const(dtypes.bool, False)), lambda: UOp.const(dtypes.bool, False)),
(UOp.const(dtypes.bool, True).lt(UOp.var()), lambda: UOp.const(dtypes.bool, False)),
# a conditional with the same results either way is a noop, also fold const conditionals
(UOp.var().where(UOp.var("val"), UOp.var("val")), lambda val: val),
(UOp.cvar('gate').where(UOp.var('c0'), UOp.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1),

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Optional, Tuple, Any, Set, cast, List, Union, DefaultDict, Callable, Dict
import functools, itertools, math
import functools, itertools
from collections import defaultdict
from enum import Enum, auto
from dataclasses import dataclass
@@ -102,12 +102,12 @@ class UOp:
def vmax(self) -> UOp:
if self.op is UOps.DEFINE_VAR: return self.src[1]
if self.op is UOps.CONST: return self
return UOp.const(dtypes.float, math.inf)
return self.const(dtypes.max(cast(DType, self.dtype)))
@functools.cached_property
def vmin(self) -> UOp:
if self.op is UOps.DEFINE_VAR: return self.src[0]
if self.op is UOps.CONST: return self
return UOp.const(dtypes.float, -math.inf)
return self.const(dtypes.min(cast(DType, self.dtype)))
class UPat:
def __init__(self, op:Optional[Union[UOps, Set[UOps]]]=None, arg:Any=None, src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None,