narrow return type of bool, int, float on UOp [pr] (#7246)

This commit is contained in:
chenyu
2024-10-23 21:06:43 -04:00
committed by GitHub
parent 9f370cccb3
commit 451c043552

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar, DefaultDict
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, TypeVar, DefaultDict
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle
from enum import auto, IntEnum, Enum
from dataclasses import dataclass, field
@@ -234,12 +234,11 @@ class UOp(MathTrait):
with Context(TRACK_MATCH_STATS=0):
return graph_rewrite(self, symbolic)
def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is UOps.CONST else ret
def _eval(self, dtype, expected_type) -> ConstType:
def _eval(self, dtype, expected_type:Type[T]) -> T:
assert self.dtype in dtype, f"eval with wrong dtype {self}"
simple_self = self.simplify()
vmin, vmax = simple_self._min_max
vmin, vmax = (simple_self:=self.simplify())._min_max
if vmin != vmax: raise ValueError(f"eval failed to be a single number, range is {vmin} to {vmax} in {simple_self.render()}")
assert type(vmin) is expected_type, f"vmin is wrong dtype {vmin} != {expected_type}"
assert isinstance(vmin, expected_type), f"vmin is wrong dtype {type(vmin)} != {expected_type}"
return vmin
def __bool__(self): return self._eval((dtypes.bool,), bool)
def __int__(self): return self._eval(dtypes.ints, int)