fix maximum/where Scalar casting (#3194)

* init

* test: added dtype tests for maximum

* fix: separate maximum const and maximum tensors

* fix: del useless line

* fix: some dtypes

* CODE GOLF: we golfing at mar-a-lago golf club tonight boyyyys

* fix: add lil helper function

* fix: some test refactoring

* done

* sike: not done yet lol

* wtf I missed an assert, am I drunk

* yeah idk

* fix: line save from redundant check

* revert: line save

* fix: simplify test_broadcast cuz I'm stumped

* change some test name

* fix: bool max bool works

* test: add a maximum bool test

* test: make sure minimum also works with bool

* fix: something like this? :s

* fix: maybe this?

* fix: how about this? tighter check

* fix: this.

* revert: nvm mul(0.5) and div(2) have the same kernel for backward

* fix: .is_floating_point() xD

* revert: maximum and minimum and add cast

* fix: cover negative const case in test

* fix: use eq because I don't understand clang :D

* WHOOOOPS
Author: geohotstan
Date: 2024-01-26 01:26:04 +08:00
Committed by: GitHub
Parent: 3628bea910
Commit: d0e116c6d6
3 changed files with 52 additions and 38 deletions
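In short: this change stops Python scalars (and the old (self+x)/2 tie-breaking path in maximum) from silently upcasting a tensor's dtype. A minimal sketch of the behavior the new tests pin down (assuming a tinygrad checkout that includes this commit; imports follow the library layout at the time):

from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes

t = Tensor([1, 2, 3], dtype=dtypes.int32)
assert t.maximum(2).dtype == dtypes.int32                        # int scalar keeps int32
assert t.maximum(2.5).dtype == dtypes.default_float              # float scalar promotes to the default float
assert Tensor([True, False]).maximum(True).dtype == dtypes.bool  # bool stays bool
assert Tensor([True, False]).where(2, 3.5).dtype == dtypes.default_float  # two scalars fall back to the defaults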

@@ -397,40 +397,12 @@ class TestAutoCastType(unittest.TestCase):
    # float16 can have larger precision errors
    np.testing.assert_allclose(func(Tensor(a, dtype=dtype)).numpy(), func(torch.tensor(a)), rtol=1e-3, atol=1e-3)
  @given(strat.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
  def test_broadcast_float(self, default_float):
    dtypes.default_float = default_float
    assert (Tensor.rand(4, 4, dtype=dtypes.bool) + 2.3).dtype == dtypes.default_float
    assert (Tensor.rand(4, 4, dtype=dtypes.int) + 2.3).dtype == dtypes.default_float
    assert (Tensor.rand(4, 4, dtype=dtypes.int8) + 2.3).dtype == dtypes.default_float
    assert (Tensor.rand(4, 4, dtype=dtypes.uint64) + 2.3).dtype == dtypes.default_float
    assert (Tensor.rand(4, 4, dtype=dtypes.float16) + 2.3).dtype == dtypes.float16
    assert (Tensor.rand(4, 4, dtype=dtypes.bfloat16) + 2.3).dtype == dtypes.bfloat16
    assert (Tensor.rand(4, 4, dtype=dtypes.float32) + 2.3).dtype == dtypes.float32
    assert (Tensor.rand(4, 4, dtype=dtypes.float64) + 2.3).dtype == dtypes.float64
  @given(strat.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]))
  def test_broadcast_int(self, default_int):
    dtypes.default_int = default_int
    assert (Tensor.rand(4, 4, dtype=dtypes.bool) + 2).dtype == dtypes.default_int
    assert (Tensor.rand(4, 4, dtype=dtypes.int) + 2).dtype == dtypes.int
    assert (Tensor.rand(4, 4, dtype=dtypes.int8) + 2).dtype == dtypes.int8
    assert (Tensor.rand(4, 4, dtype=dtypes.uint64) + 2).dtype == dtypes.uint64
    assert (Tensor.rand(4, 4, dtype=dtypes.float16) + 2).dtype == dtypes.float16
    assert (Tensor.rand(4, 4, dtype=dtypes.bfloat16) + 2).dtype == dtypes.bfloat16
    assert (Tensor.rand(4, 4, dtype=dtypes.float32) + 2).dtype == dtypes.float32
    assert (Tensor.rand(4, 4, dtype=dtypes.float64) + 2).dtype == dtypes.float64
  def test_broadcast_bool(self):
    if Device.DEFAULT != "WEBGPU":
      assert (Tensor([0, 1], dtype=dtypes.bool) + True).dtype == dtypes.bool
      assert (Tensor([0, 1], dtype=dtypes.int) + True).dtype == dtypes.int32
      assert (Tensor([0, 1], dtype=dtypes.int8) + True).dtype == dtypes.int8
      assert (Tensor([0, 1], dtype=dtypes.uint64) + True).dtype == dtypes.uint64
      assert (Tensor([0, 1], dtype=dtypes.float16) + True).dtype == dtypes.float16
      assert (Tensor([0, 1], dtype=dtypes.bfloat16) + True).dtype == dtypes.bfloat16
      assert (Tensor([0, 1], dtype=dtypes.float32) + True).dtype == dtypes.float32
      assert (Tensor([0, 1], dtype=dtypes.float64) + True).dtype == dtypes.float64
  @given(strat.sampled_from(core_dtypes))
  def test_broadcast_scalar(self, dt):
    assert (Tensor.rand(4, 4, dtype=dt) + 2.3).dtype == (dt if dtypes.is_float(dt) else dtypes.default_float)
    assert (Tensor.rand(4, 4, dtype=dt) + 2).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
    if Device.DEFAULT != "WEBGPU" and dt != dtypes.bool:
      assert (Tensor.rand(4, 4, dtype=dt) + True).dtype == dt
  def test_sum(self):
    assert (Tensor([0, 1], dtype=dtypes.bool)).sum().dtype == dtypes.int32
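The per-dtype broadcast tests above are folded into the property-based test_broadcast_scalar. The rule it asserts can be restated as a small helper (expected_scalar_dtype is purely illustrative, not a tinygrad function):

from tinygrad.dtype import dtypes

def expected_scalar_dtype(dt, scalar):
  # check bool before int: bool is a subclass of int in Python
  if isinstance(scalar, bool): return dt
  if isinstance(scalar, int): return dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int
  return dt if dtypes.is_float(dt) else dtypes.default_float  # float scalars promote non-float tensors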
@@ -466,5 +438,39 @@ class TestAutoCastType(unittest.TestCase):
  def test_matmul(self, dt1, dt2):
    assert (Tensor([0, 1], dtype=dt1) @ Tensor([0, 1], dtype=dt2)).dtype == least_upper_dtype(dt1, dt2)
  @staticmethod
  def check_where_alternate_input_other(input_, other, data_type):
    assert (Tensor([True, False]).where(input_, other)).dtype == data_type
    assert (Tensor([True, False]).where(other, input_)).dtype == data_type
  @given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
  def test_where_no_scalar(self, dt1, dt2):
    self.check_where_alternate_input_other(Tensor(2, dtype=dt1), Tensor(3, dtype=dt2), least_upper_dtype(dt1, dt2))
  @given(strat.sampled_from(core_dtypes))
  def test_where_one_scalar(self, dt):
    t = Tensor(2, dtype=dt)
    self.check_where_alternate_input_other(t, 3.2, (dt if dtypes.is_float(dt) else dtypes.default_float))
    self.check_where_alternate_input_other(t, 3, (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int))
    self.check_where_alternate_input_other(t, True, dt)
  def test_where_two_scalars(self):
    self.check_where_alternate_input_other(3.1, 3.2, dtypes.default_float)
    self.check_where_alternate_input_other(3.1, 3, dtypes.default_float)
    self.check_where_alternate_input_other(3.1, True, dtypes.default_float)
    self.check_where_alternate_input_other(3, 2, dtypes.default_int)
    self.check_where_alternate_input_other(3, True, dtypes.default_int)
    self.check_where_alternate_input_other(False, True, dtypes.bool)
  @given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
  def test_maximum(self, dt1, dt2):
    assert Tensor([0, 1, 2], dtype=dt1).maximum(Tensor([2, 0, 5], dtype=dt2)).dtype == least_upper_dtype(dt1, dt2)
  @given(strat.sampled_from(core_dtypes))
  def test_maximum_const(self, dt):
    assert Tensor([1, 2], dtype=dt).maximum(3.1).dtype == (dt if dtypes.is_float(dt) else dtypes.default_float)
    assert Tensor([1, 2], dtype=dt).maximum(3).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
    assert Tensor([1, 2], dtype=dt).maximum(True).dtype == dt

if __name__ == '__main__':
  unittest.main()
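The where tests encode the same promotion rule when one branch is a Tensor and the other a Python scalar: bool and int scalars defer to the Tensor's dtype, and only a float scalar promotes a non-float tensor. A quick check under the same assumptions as the sketch above:

from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes

cond = Tensor([True, False])
assert cond.where(Tensor([1, 2], dtype=dtypes.int8), 5).dtype == dtypes.int8             # int scalar keeps int8
assert cond.where(Tensor([1, 2], dtype=dtypes.int8), 5.0).dtype == dtypes.default_float  # float scalar promotes
assert cond.where(Tensor([1, 2], dtype=dtypes.float16), True).dtype == dtypes.float16    # bool scalar keeps float16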

@@ -262,11 +262,17 @@ class TestOps(unittest.TestCase):
  def test_maximum(self):
    helper_test_op([(45,65), (45,65)], torch.maximum, Tensor.maximum)
    helper_test_op([(), ()], torch.maximum, Tensor.maximum)
    helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., 4.], [1., 2., 3., 0.]])
    helper_test_op(None, torch.maximum, Tensor.maximum, vals=np.array([[1, 0, 3, 4], [1, 2, 3, 0]], dtype=np.int32), forward_only=True)
    helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., -4.], 3.])
    helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., -4.], [-1., -2., 3., 0.]])
    helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[True, False, False], True], forward_only=True)
    helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[True, False, False], [True, True, False]], forward_only=True)
  def test_minimum(self):
    helper_test_op([(45,65), (45,65)], torch.minimum, Tensor.minimum)
    helper_test_op([(), ()], torch.minimum, Tensor.minimum)
    helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[1., 0., 3., -4.], 3.])
    helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[1., 0., 3., -4.], [-1., -2., 3., 0.]])
    helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], True], forward_only=True)
    helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], [True, True, False]], forward_only=True)
  def test_add(self):
    helper_test_op([(45,68), (45,68)], lambda x,y: x+y)
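The int and bool cases above are forward_only because there is no gradient to check for non-float inputs. A standalone forward comparison in the same spirit (this bypasses helper_test_op and checks against torch directly):

import numpy as np
import torch
from tinygrad.tensor import Tensor

a, b = [True, False, False], [True, True, False]
np.testing.assert_equal(Tensor(a).maximum(Tensor(b)).numpy(),
                        torch.maximum(torch.tensor(a), torch.tensor(b)).numpy())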

@@ -837,11 +837,13 @@ class Tensor:
    return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan)
  def xor(self, x:Tensor, reverse=False) -> Tensor: return mlops.Xor.apply(*self._broadcasted(x, reverse))
  # TODO: this implicitly changes dtype with /2
  def maximum(self, x:Union[Tensor, Scalar]) -> Tensor: return (self<x).detach().where(x, (self>x).detach().where(self, (self+x)/2))
  def maximum(self, x:Union[Tensor, Scalar]) -> Tensor:
    return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self))
  def minimum(self, x:Union[Tensor, Scalar]) -> Tensor: return -((-self).maximum(-x))
  def where(self:Tensor, input_:Union[Tensor, Scalar], other:Union[Tensor, Scalar]):
    if isinstance(input_, Tensor): input_, other = input_._broadcasted(other)
    elif isinstance(other, Tensor): other, input_ = other._broadcasted(input_)
    x_,y = self._broadcasted(input_, match_dtype=False)
    x,z = x_._broadcasted(other, match_dtype=False)
    return mlops.Where.apply(x.cast(dtypes.bool), *y._broadcasted(z))
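The rewritten maximum resolves ties by averaging (self * 0.5 + x * 0.5) and casting the result back to self.dtype, so the old /2 path can no longer upcast integers or bools, while float ties still route gradient through both inputs; the two new isinstance lines in where broadcast input_ and other against each other first, so a Tensor operand's dtype, not the Python scalar, drives promotion. A small sanity check under the same assumptions as the earlier sketches:

from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes

a = Tensor([1, 0, 3, -4], dtype=dtypes.int32)
assert a.maximum(Tensor([1, 2, 2, 2], dtype=dtypes.int32)).dtype == dtypes.int32  # tie at index 0, still int32
assert a.maximum(-1).dtype == dtypes.int32                                         # negative int const keeps int32
assert a.minimum(3).dtype == dtypes.int32                                          # minimum is -((-a).maximum(-3)), same rule
print(a.maximum(-1).numpy())  # [ 1  0  3 -1]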