mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
narrow return type of bool, int, float on UOp [pr] (#7246)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar, DefaultDict
|
||||
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, TypeVar, DefaultDict
|
||||
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle
|
||||
from enum import auto, IntEnum, Enum
|
||||
from dataclasses import dataclass, field
|
||||
@@ -234,12 +234,11 @@ class UOp(MathTrait):
|
||||
with Context(TRACK_MATCH_STATS=0):
|
||||
return graph_rewrite(self, symbolic)
|
||||
def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is UOps.CONST else ret
|
||||
def _eval(self, dtype, expected_type) -> ConstType:
|
||||
def _eval(self, dtype, expected_type:Type[T]) -> T:
|
||||
assert self.dtype in dtype, f"eval with wrong dtype {self}"
|
||||
simple_self = self.simplify()
|
||||
vmin, vmax = simple_self._min_max
|
||||
vmin, vmax = (simple_self:=self.simplify())._min_max
|
||||
if vmin != vmax: raise ValueError(f"eval failed to be a single number, range is {vmin} to {vmax} in {simple_self.render()}")
|
||||
assert type(vmin) is expected_type, f"vmin is wrong dtype {vmin} != {expected_type}"
|
||||
assert isinstance(vmin, expected_type), f"vmin is wrong dtype {type(vmin)} != {expected_type}"
|
||||
return vmin
|
||||
def __bool__(self): return self._eval((dtypes.bool,), bool)
|
||||
def __int__(self): return self._eval(dtypes.ints, int)
|
||||
|
||||
Reference in New Issue
Block a user