mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
fix maximum/where Scalar casting (#3194)
* init * test: added dtype tests for maximum * fix: seperate maximum const and maximum tensors * fix: del useless line * fix: some dtypes * CODE GOLF: we golfing at mar-a-lago golf club tonight boyyyys * fix: add lil helper function * fix: some test refactoring * done * sike: not done yet lol * wtf I missed an assert, am I drunk * yeah idk * fix: line save from redundant check * revert: line save * fix: simplify test_broadcast cuz I'm stumped * change some test name * fix: bool max bool works * test: add a maximum bool test * test: make sure minimum also works with bool * fix: something like this? :s * fix: maybe this? * fix: how about this? tighter check * fix: this. * revert: nvm mul(0.5) and div(2) has the same kernel for backward * fix: .is_floating_point() xD * revert: maximum and minimum and add cast * fix: cover negative const case in test * fix: use eq because I don't understand clang :D * WHOOOOPS
This commit is contained in:
@@ -397,40 +397,12 @@ class TestAutoCastType(unittest.TestCase):
|
||||
# float16 can have larger precision errors
|
||||
np.testing.assert_allclose(func(Tensor(a, dtype=dtype)).numpy(), func(torch.tensor(a)), rtol=1e-3, atol=1e-3)
|
||||
|
||||
@given(strat.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
|
||||
def test_broadcast_float(self, default_float):
|
||||
dtypes.default_float = default_float
|
||||
assert (Tensor.rand(4, 4, dtype=dtypes.bool) + 2.3).dtype == dtypes.default_float
|
||||
assert (Tensor.rand(4, 4, dtype=dtypes.int) + 2.3).dtype == dtypes.default_float
|
||||
assert (Tensor.rand(4, 4, dtype=dtypes.int8) + 2.3).dtype == dtypes.default_float
|
||||
assert (Tensor.rand(4, 4, dtype=dtypes.uint64) + 2.3).dtype == dtypes.default_float
|
||||
assert (Tensor.rand(4, 4, dtype=dtypes.float16) + 2.3).dtype == dtypes.float16
|
||||
assert (Tensor.rand(4, 4, dtype=dtypes.bfloat16) + 2.3).dtype == dtypes.bfloat16
|
||||
assert (Tensor.rand(4, 4, dtype=dtypes.float32) + 2.3).dtype == dtypes.float32
|
||||
assert (Tensor.rand(4, 4, dtype=dtypes.float64) + 2.3).dtype == dtypes.float64
|
||||
|
||||
@given(strat.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]))
|
||||
def test_broadcast_int(self, default_int):
|
||||
dtypes.default_int = default_int
|
||||
assert (Tensor.rand(4, 4, dtype=dtypes.bool) + 2).dtype == dtypes.default_int
|
||||
assert (Tensor.rand(4, 4, dtype=dtypes.int) + 2).dtype == dtypes.int
|
||||
assert (Tensor.rand(4, 4, dtype=dtypes.int8) + 2).dtype == dtypes.int8
|
||||
assert (Tensor.rand(4, 4, dtype=dtypes.uint64) + 2).dtype == dtypes.uint64
|
||||
assert (Tensor.rand(4, 4, dtype=dtypes.float16) + 2).dtype == dtypes.float16
|
||||
assert (Tensor.rand(4, 4, dtype=dtypes.bfloat16) + 2).dtype == dtypes.bfloat16
|
||||
assert (Tensor.rand(4, 4, dtype=dtypes.float32) + 2).dtype == dtypes.float32
|
||||
assert (Tensor.rand(4, 4, dtype=dtypes.float64) + 2).dtype == dtypes.float64
|
||||
|
||||
def test_broadcast_bool(self):
|
||||
if Device.DEFAULT != "WEBGPU":
|
||||
assert (Tensor([0, 1], dtype=dtypes.bool) + True).dtype == dtypes.bool
|
||||
assert (Tensor([0, 1], dtype=dtypes.int) + True).dtype == dtypes.int32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int8) + True).dtype == dtypes.int8
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint64) + True).dtype == dtypes.uint64
|
||||
assert (Tensor([0, 1], dtype=dtypes.float16) + True).dtype == dtypes.float16
|
||||
assert (Tensor([0, 1], dtype=dtypes.bfloat16) + True).dtype == dtypes.bfloat16
|
||||
assert (Tensor([0, 1], dtype=dtypes.float32) + True).dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.float64) + True).dtype == dtypes.float64
|
||||
@given(strat.sampled_from(core_dtypes))
|
||||
def test_broadcast_scalar(self, dt):
|
||||
assert (Tensor.rand(4, 4, dtype=dt) + 2.3).dtype == (dt if dtypes.is_float(dt) else dtypes.default_float)
|
||||
assert (Tensor.rand(4, 4, dtype=dt) + 2).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
|
||||
if Device.DEFAULT != "WEBGPU" and dt != dtypes.bool:
|
||||
assert (Tensor.rand(4, 4, dtype=dt) + True).dtype == dt
|
||||
|
||||
def test_sum(self):
|
||||
assert (Tensor([0, 1], dtype=dtypes.bool)).sum().dtype == dtypes.int32
|
||||
@@ -466,5 +438,39 @@ class TestAutoCastType(unittest.TestCase):
|
||||
def test_matmul(self, dt1, dt2):
|
||||
assert (Tensor([0, 1], dtype=dt1) @ Tensor([0, 1], dtype=dt2)).dtype == least_upper_dtype(dt1, dt2)
|
||||
|
||||
@staticmethod
|
||||
def check_where_alternate_input_other(input_, other, data_type):
|
||||
assert (Tensor([True, False]).where(input_, other)).dtype == data_type
|
||||
assert (Tensor([True, False]).where(other, input_)).dtype == data_type
|
||||
|
||||
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
|
||||
def test_where_no_scalar(self, dt1, dt2):
|
||||
self.check_where_alternate_input_other(Tensor(2, dtype=dt1), Tensor(3, dtype=dt2), least_upper_dtype(dt1, dt2))
|
||||
|
||||
@given(strat.sampled_from(core_dtypes))
|
||||
def test_where_one_scalar(self, dt):
|
||||
t = Tensor(2, dtype=dt)
|
||||
self.check_where_alternate_input_other(t, 3.2, (dt if dtypes.is_float(dt) else dtypes.default_float))
|
||||
self.check_where_alternate_input_other(t, 3, (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int))
|
||||
self.check_where_alternate_input_other(t, True, dt)
|
||||
|
||||
def test_where_two_scalars(self):
|
||||
self.check_where_alternate_input_other(3.1, 3.2, dtypes.default_float)
|
||||
self.check_where_alternate_input_other(3.1, 3, dtypes.default_float)
|
||||
self.check_where_alternate_input_other(3.1, True, dtypes.default_float)
|
||||
self.check_where_alternate_input_other(3, 2, dtypes.default_int)
|
||||
self.check_where_alternate_input_other(3, True, dtypes.default_int)
|
||||
self.check_where_alternate_input_other(False, True, dtypes.bool)
|
||||
|
||||
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
|
||||
def test_maximum(self, dt1, dt2):
|
||||
assert Tensor([0, 1, 2], dtype=dt1).maximum(Tensor([2, 0, 5], dtype=dt2)).dtype == least_upper_dtype(dt1, dt2)
|
||||
|
||||
@given(strat.sampled_from(core_dtypes))
|
||||
def test_maximum_const(self, dt):
|
||||
assert Tensor([1, 2], dtype=dt).maximum(3.1).dtype == (dt if dtypes.is_float(dt) else dtypes.default_float)
|
||||
assert Tensor([1, 2], dtype=dt).maximum(3).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
|
||||
assert Tensor([1, 2], dtype=dt).maximum(True).dtype == dt
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user