fix round

This commit is contained in:
qazal
2023-12-22 22:45:44 +02:00
parent 7ebac1018f
commit d08a556a75

View File

@@ -17,6 +17,7 @@ def Identity(x: Tensor): return x
def Add(x: Tensor, other: Tensor, broadcast=None): return x + other if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + other).cast(x.dtype)
def Sub(x: Union[Tensor, Any], other: Tensor): return x - other # some test has input as int
def Div(x: Tensor, other: Tensor): return x / other if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else x.div(other).floor() # TODO: this has dtype issues
# TODO get rid of casts
def Pow(x: Tensor, other: Tensor): return x.float() ** other.float()
def Less(x:Tensor,y:Tensor): return (x<y).cast(dtypes.bool)
def LessOrEqual(x:Tensor,y:Tensor): return (x<=y).cast(dtypes.bool)
@@ -406,7 +407,7 @@ def _round(x:Tensor, n:float, equidistant_case = "round_down") -> Tensor:
if equidistant_case == "round_to_even":
def _and(cond1, cond2): return ((cond1 + cond2) == 2).where(1, 0)
x_ceil_fraction = x.ceil()/2
cond_ceil_even = x_ceil_fraction.ceil() == x_ceil_fraction
cond_ceil_even = (x_ceil_fraction.ceil() == x_ceil_fraction).float()
x = (_and(x == b, cond_ceil_even)).where(x+1-n, x)
x = (x > b).where(b+1-n, b-n)
return x