mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
fix round
This commit is contained in:
@@ -17,6 +17,7 @@ def Identity(x: Tensor): return x
|
||||
def Add(x: Tensor, other: Tensor, broadcast=None): return x + other if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + other).cast(x.dtype)
|
||||
def Sub(x: Union[Tensor, Any], other: Tensor): return x - other # some test has input as int
|
||||
def Div(x: Tensor, other: Tensor): return x / other if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else x.div(other).floor() # TODO: this has dtype issues
|
||||
# TODO get rid of casts
|
||||
def Pow(x: Tensor, other: Tensor): return x.float() ** other.float()
|
||||
def Less(x:Tensor,y:Tensor): return (x<y).cast(dtypes.bool)
|
||||
def LessOrEqual(x:Tensor,y:Tensor): return (x<=y).cast(dtypes.bool)
|
||||
@@ -406,7 +407,7 @@ def _round(x:Tensor, n:float, equidistant_case = "round_down") -> Tensor:
|
||||
if equidistant_case == "round_to_even":
|
||||
def _and(cond1, cond2): return ((cond1 + cond2) == 2).where(1, 0)
|
||||
x_ceil_fraction = x.ceil()/2
|
||||
cond_ceil_even = x_ceil_fraction.ceil() == x_ceil_fraction
|
||||
cond_ceil_even = (x_ceil_fraction.ceil() == x_ceil_fraction).float()
|
||||
x = (_and(x == b, cond_ceil_even)).where(x+1-n, x)
|
||||
x = (x > b).where(b+1-n, b-n)
|
||||
return x
|
||||
|
||||
Reference in New Issue
Block a user