fix maximum/where Scalar casting (#3194)

* init

* test: added dtype tests for maximum

* fix: seperate maximum const and maximum tensors

* fix: del useless line

* fix: some dtypes

* CODE GOLF: we golfing at mar-a-lago golf club tonight boyyyys

* fix: add lil helper function

* fix: some test refactoring

* done

* sike: not done yet lol

* wtf I missed an assert, am I drunk

* yeah idk

* fix: line save from redundant check

* revert: line save

* fix: simplify test_broadcast cuz I'm stumped

* change some test name

* fix: bool max bool  works

* test: add a maximum bool test

* test: make sure minimum also works with bool

* fix: something like this? :s

* fix: maybe this?

* fix: how about this? tighter check

* fix: this.

* revert: nvm mul(0.5) and div(2) has the same kernel for backward

* fix: .is_floating_point() xD

* revert: maximum and minimum and add cast

* fix: cover negative const case in test

* fix: use eq because I don't understand clang :D

* WHOOOOPS
This commit is contained in:
geohotstan
2024-01-26 01:26:04 +08:00
committed by GitHub
parent 3628bea910
commit d0e116c6d6
3 changed files with 52 additions and 38 deletions

View File

@@ -397,40 +397,12 @@ class TestAutoCastType(unittest.TestCase):
# float16 can have larger precision errors
np.testing.assert_allclose(func(Tensor(a, dtype=dtype)).numpy(), func(torch.tensor(a)), rtol=1e-3, atol=1e-3)
@given(strat.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
def test_broadcast_float(self, default_float):
dtypes.default_float = default_float
assert (Tensor.rand(4, 4, dtype=dtypes.bool) + 2.3).dtype == dtypes.default_float
assert (Tensor.rand(4, 4, dtype=dtypes.int) + 2.3).dtype == dtypes.default_float
assert (Tensor.rand(4, 4, dtype=dtypes.int8) + 2.3).dtype == dtypes.default_float
assert (Tensor.rand(4, 4, dtype=dtypes.uint64) + 2.3).dtype == dtypes.default_float
assert (Tensor.rand(4, 4, dtype=dtypes.float16) + 2.3).dtype == dtypes.float16
assert (Tensor.rand(4, 4, dtype=dtypes.bfloat16) + 2.3).dtype == dtypes.bfloat16
assert (Tensor.rand(4, 4, dtype=dtypes.float32) + 2.3).dtype == dtypes.float32
assert (Tensor.rand(4, 4, dtype=dtypes.float64) + 2.3).dtype == dtypes.float64
@given(strat.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]))
def test_broadcast_int(self, default_int):
dtypes.default_int = default_int
assert (Tensor.rand(4, 4, dtype=dtypes.bool) + 2).dtype == dtypes.default_int
assert (Tensor.rand(4, 4, dtype=dtypes.int) + 2).dtype == dtypes.int
assert (Tensor.rand(4, 4, dtype=dtypes.int8) + 2).dtype == dtypes.int8
assert (Tensor.rand(4, 4, dtype=dtypes.uint64) + 2).dtype == dtypes.uint64
assert (Tensor.rand(4, 4, dtype=dtypes.float16) + 2).dtype == dtypes.float16
assert (Tensor.rand(4, 4, dtype=dtypes.bfloat16) + 2).dtype == dtypes.bfloat16
assert (Tensor.rand(4, 4, dtype=dtypes.float32) + 2).dtype == dtypes.float32
assert (Tensor.rand(4, 4, dtype=dtypes.float64) + 2).dtype == dtypes.float64
def test_broadcast_bool(self):
if Device.DEFAULT != "WEBGPU":
assert (Tensor([0, 1], dtype=dtypes.bool) + True).dtype == dtypes.bool
assert (Tensor([0, 1], dtype=dtypes.int) + True).dtype == dtypes.int32
assert (Tensor([0, 1], dtype=dtypes.int8) + True).dtype == dtypes.int8
assert (Tensor([0, 1], dtype=dtypes.uint64) + True).dtype == dtypes.uint64
assert (Tensor([0, 1], dtype=dtypes.float16) + True).dtype == dtypes.float16
assert (Tensor([0, 1], dtype=dtypes.bfloat16) + True).dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.float32) + True).dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.float64) + True).dtype == dtypes.float64
@given(strat.sampled_from(core_dtypes))
def test_broadcast_scalar(self, dt):
assert (Tensor.rand(4, 4, dtype=dt) + 2.3).dtype == (dt if dtypes.is_float(dt) else dtypes.default_float)
assert (Tensor.rand(4, 4, dtype=dt) + 2).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
if Device.DEFAULT != "WEBGPU" and dt != dtypes.bool:
assert (Tensor.rand(4, 4, dtype=dt) + True).dtype == dt
def test_sum(self):
assert (Tensor([0, 1], dtype=dtypes.bool)).sum().dtype == dtypes.int32
@@ -466,5 +438,39 @@ class TestAutoCastType(unittest.TestCase):
def test_matmul(self, dt1, dt2):
assert (Tensor([0, 1], dtype=dt1) @ Tensor([0, 1], dtype=dt2)).dtype == least_upper_dtype(dt1, dt2)
@staticmethod
def check_where_alternate_input_other(input_, other, data_type):
assert (Tensor([True, False]).where(input_, other)).dtype == data_type
assert (Tensor([True, False]).where(other, input_)).dtype == data_type
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
def test_where_no_scalar(self, dt1, dt2):
self.check_where_alternate_input_other(Tensor(2, dtype=dt1), Tensor(3, dtype=dt2), least_upper_dtype(dt1, dt2))
@given(strat.sampled_from(core_dtypes))
def test_where_one_scalar(self, dt):
t = Tensor(2, dtype=dt)
self.check_where_alternate_input_other(t, 3.2, (dt if dtypes.is_float(dt) else dtypes.default_float))
self.check_where_alternate_input_other(t, 3, (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int))
self.check_where_alternate_input_other(t, True, dt)
def test_where_two_scalars(self):
self.check_where_alternate_input_other(3.1, 3.2, dtypes.default_float)
self.check_where_alternate_input_other(3.1, 3, dtypes.default_float)
self.check_where_alternate_input_other(3.1, True, dtypes.default_float)
self.check_where_alternate_input_other(3, 2, dtypes.default_int)
self.check_where_alternate_input_other(3, True, dtypes.default_int)
self.check_where_alternate_input_other(False, True, dtypes.bool)
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
def test_maximum(self, dt1, dt2):
assert Tensor([0, 1, 2], dtype=dt1).maximum(Tensor([2, 0, 5], dtype=dt2)).dtype == least_upper_dtype(dt1, dt2)
@given(strat.sampled_from(core_dtypes))
def test_maximum_const(self, dt):
assert Tensor([1, 2], dtype=dt).maximum(3.1).dtype == (dt if dtypes.is_float(dt) else dtypes.default_float)
assert Tensor([1, 2], dtype=dt).maximum(3).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
assert Tensor([1, 2], dtype=dt).maximum(True).dtype == dt
if __name__ == '__main__':
unittest.main()