mirror of https://github.com/tinygrad/tinygrad.git
fix maximum/where Scalar casting (#3194)
* init
* test: added dtype tests for maximum
* fix: separate maximum const and maximum tensors
* fix: del useless line
* fix: some dtypes
* CODE GOLF: we golfing at mar-a-lago golf club tonight boyyyys
* fix: add lil helper function
* fix: some test refactoring
* done
* sike: not done yet lol
* wtf I missed an assert, am I drunk
* yeah idk
* fix: line save from redundant check
* revert: line save
* fix: simplify test_broadcast cuz I'm stumped
* change some test name
* fix: bool max bool works
* test: add a maximum bool test
* test: make sure minimum also works with bool
* fix: something like this? :s
* fix: maybe this?
* fix: how about this? tighter check
* fix: this.
* revert: nvm mul(0.5) and div(2) have the same kernel for backward
* fix: .is_floating_point() xD
* revert: maximum and minimum and add cast
* fix: cover negative const case in test
* fix: use eq because I don't understand clang :D
* WHOOOOPS
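In short, the change pins down how maximum and where treat Python scalar arguments: a scalar should not widen the tensor's dtype when the tensor already has a float or int dtype, and should only fall back to the default dtype when it has to. A minimal sketch of the intended behavior (hypothetical usage, assuming a tinygrad checkout with this patch applied):

import numpy as np  # only to make the example self-contained
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes

t = Tensor([1, 2], dtype=dtypes.int8)
print(t.maximum(3).dtype)      # int8: an int scalar keeps the tensor's int dtype
print(t.maximum(3.1).dtype)    # default_float (float32 unless overridden): a float scalar promotes an int tensor
print(t.maximum(True).dtype)   # int8: a bool scalar never promotes
print(Tensor([True, False]).where(t, 3).dtype)  # int8: where() follows the same rule for its scalar branch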
@@ -397,40 +397,12 @@ class TestAutoCastType(unittest.TestCase):
     # float16 can have larger precision errors
     np.testing.assert_allclose(func(Tensor(a, dtype=dtype)).numpy(), func(torch.tensor(a)), rtol=1e-3, atol=1e-3)
 
-  @given(strat.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
-  def test_broadcast_float(self, default_float):
-    dtypes.default_float = default_float
-    assert (Tensor.rand(4, 4, dtype=dtypes.bool) + 2.3).dtype == dtypes.default_float
-    assert (Tensor.rand(4, 4, dtype=dtypes.int) + 2.3).dtype == dtypes.default_float
-    assert (Tensor.rand(4, 4, dtype=dtypes.int8) + 2.3).dtype == dtypes.default_float
-    assert (Tensor.rand(4, 4, dtype=dtypes.uint64) + 2.3).dtype == dtypes.default_float
-    assert (Tensor.rand(4, 4, dtype=dtypes.float16) + 2.3).dtype == dtypes.float16
-    assert (Tensor.rand(4, 4, dtype=dtypes.bfloat16) + 2.3).dtype == dtypes.bfloat16
-    assert (Tensor.rand(4, 4, dtype=dtypes.float32) + 2.3).dtype == dtypes.float32
-    assert (Tensor.rand(4, 4, dtype=dtypes.float64) + 2.3).dtype == dtypes.float64
-
-  @given(strat.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]))
-  def test_broadcast_int(self, default_int):
-    dtypes.default_int = default_int
-    assert (Tensor.rand(4, 4, dtype=dtypes.bool) + 2).dtype == dtypes.default_int
-    assert (Tensor.rand(4, 4, dtype=dtypes.int) + 2).dtype == dtypes.int
-    assert (Tensor.rand(4, 4, dtype=dtypes.int8) + 2).dtype == dtypes.int8
-    assert (Tensor.rand(4, 4, dtype=dtypes.uint64) + 2).dtype == dtypes.uint64
-    assert (Tensor.rand(4, 4, dtype=dtypes.float16) + 2).dtype == dtypes.float16
-    assert (Tensor.rand(4, 4, dtype=dtypes.bfloat16) + 2).dtype == dtypes.bfloat16
-    assert (Tensor.rand(4, 4, dtype=dtypes.float32) + 2).dtype == dtypes.float32
-    assert (Tensor.rand(4, 4, dtype=dtypes.float64) + 2).dtype == dtypes.float64
-
-  def test_broadcast_bool(self):
-    if Device.DEFAULT != "WEBGPU":
-      assert (Tensor([0, 1], dtype=dtypes.bool) + True).dtype == dtypes.bool
-      assert (Tensor([0, 1], dtype=dtypes.int) + True).dtype == dtypes.int32
-      assert (Tensor([0, 1], dtype=dtypes.int8) + True).dtype == dtypes.int8
-      assert (Tensor([0, 1], dtype=dtypes.uint64) + True).dtype == dtypes.uint64
-      assert (Tensor([0, 1], dtype=dtypes.float16) + True).dtype == dtypes.float16
-      assert (Tensor([0, 1], dtype=dtypes.bfloat16) + True).dtype == dtypes.bfloat16
-      assert (Tensor([0, 1], dtype=dtypes.float32) + True).dtype == dtypes.float32
-      assert (Tensor([0, 1], dtype=dtypes.float64) + True).dtype == dtypes.float64
+  @given(strat.sampled_from(core_dtypes))
+  def test_broadcast_scalar(self, dt):
+    assert (Tensor.rand(4, 4, dtype=dt) + 2.3).dtype == (dt if dtypes.is_float(dt) else dtypes.default_float)
+    assert (Tensor.rand(4, 4, dtype=dt) + 2).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
+    if Device.DEFAULT != "WEBGPU" and dt != dtypes.bool:
+      assert (Tensor.rand(4, 4, dtype=dt) + True).dtype == dt
 
   def test_sum(self):
     assert (Tensor([0, 1], dtype=dtypes.bool)).sum().dtype == dtypes.int32
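The consolidated test_broadcast_scalar above leans on hypothesis to sweep every entry of core_dtypes through one body instead of keeping a hand-written test per dtype. A self-contained illustration of that pattern (assumes the hypothesis package is installed; the string list is just a stand-in for core_dtypes):

from hypothesis import given, strategies as strat

@given(strat.sampled_from(["bool", "int8", "int32", "float16", "float32"]))  # stand-in for core_dtypes
def check_name(name):
  # hypothesis calls this body once for each drawn value
  assert isinstance(name, str)

check_name()  # calling the wrapped function runs all the drawn examples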
@@ -466,5 +438,39 @@ class TestAutoCastType(unittest.TestCase):
   def test_matmul(self, dt1, dt2):
     assert (Tensor([0, 1], dtype=dt1) @ Tensor([0, 1], dtype=dt2)).dtype == least_upper_dtype(dt1, dt2)
 
+  @staticmethod
+  def check_where_alternate_input_other(input_, other, data_type):
+    assert (Tensor([True, False]).where(input_, other)).dtype == data_type
+    assert (Tensor([True, False]).where(other, input_)).dtype == data_type
+
+  @given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
+  def test_where_no_scalar(self, dt1, dt2):
+    self.check_where_alternate_input_other(Tensor(2, dtype=dt1), Tensor(3, dtype=dt2), least_upper_dtype(dt1, dt2))
+
+  @given(strat.sampled_from(core_dtypes))
+  def test_where_one_scalar(self, dt):
+    t = Tensor(2, dtype=dt)
+    self.check_where_alternate_input_other(t, 3.2, (dt if dtypes.is_float(dt) else dtypes.default_float))
+    self.check_where_alternate_input_other(t, 3, (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int))
+    self.check_where_alternate_input_other(t, True, dt)
+
+  def test_where_two_scalars(self):
+    self.check_where_alternate_input_other(3.1, 3.2, dtypes.default_float)
+    self.check_where_alternate_input_other(3.1, 3, dtypes.default_float)
+    self.check_where_alternate_input_other(3.1, True, dtypes.default_float)
+    self.check_where_alternate_input_other(3, 2, dtypes.default_int)
+    self.check_where_alternate_input_other(3, True, dtypes.default_int)
+    self.check_where_alternate_input_other(False, True, dtypes.bool)
+
+  @given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
+  def test_maximum(self, dt1, dt2):
+    assert Tensor([0, 1, 2], dtype=dt1).maximum(Tensor([2, 0, 5], dtype=dt2)).dtype == least_upper_dtype(dt1, dt2)
+
+  @given(strat.sampled_from(core_dtypes))
+  def test_maximum_const(self, dt):
+    assert Tensor([1, 2], dtype=dt).maximum(3.1).dtype == (dt if dtypes.is_float(dt) else dtypes.default_float)
+    assert Tensor([1, 2], dtype=dt).maximum(3).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
+    assert Tensor([1, 2], dtype=dt).maximum(True).dtype == dt
+
 if __name__ == '__main__':
   unittest.main()
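The two-tensor cases (test_matmul, test_where_no_scalar, test_maximum) all defer to least_upper_dtype, tinygrad's promotion lattice. A quick check of what that lookup returns for a couple of simple pairs (hypothetical session; the commented results are the promotions these tests assume):

from tinygrad.dtype import dtypes, least_upper_dtype

# the smallest dtype both operands can be promoted to without losing information
print(least_upper_dtype(dtypes.bool, dtypes.int8))       # int8
print(least_upper_dtype(dtypes.float16, dtypes.float32)) # float32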
@@ -262,11 +262,17 @@ class TestOps(unittest.TestCase):
   def test_maximum(self):
     helper_test_op([(45,65), (45,65)], torch.maximum, Tensor.maximum)
     helper_test_op([(), ()], torch.maximum, Tensor.maximum)
     helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., 4.], [1., 2., 3., 0.]])
     helper_test_op(None, torch.maximum, Tensor.maximum, vals=np.array([[1, 0, 3, 4], [1, 2, 3, 0]], dtype=np.int32), forward_only=True)
+    helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., -4.], 3.])
+    helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., -4.], [-1., -2., 3., 0.]])
+    helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[True, False, False], True], forward_only=True)
+    helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[True, False, False], [True, True, False]], forward_only=True)
   def test_minimum(self):
     helper_test_op([(45,65), (45,65)], torch.minimum, Tensor.minimum)
     helper_test_op([(), ()], torch.minimum, Tensor.minimum)
+    helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[1., 0., 3., -4.], 3.])
+    helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[1., 0., 3., -4.], [-1., -2., 3., 0.]])
+    helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], True], forward_only=True)
+    helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], [True, True, False]], forward_only=True)
 
   def test_add(self):
     helper_test_op([(45,68), (45,68)], lambda x,y: x+y)
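Each helper_test_op line above runs the same op through torch and tinygrad on identical inputs and compares the outputs (and, unless forward_only=True, the gradients). A rough, self-contained stand-in for the forward half; check_forward is a hypothetical name, the real helper lives in test/test_ops.py:

import numpy as np
import torch
from tinygrad.tensor import Tensor

def check_forward(torch_fxn, tinygrad_fxn, vals):
  # build the same inputs for both frameworks and compare the forward results
  torch_out = torch_fxn(*[torch.tensor(v) for v in vals])
  tiny_out = tinygrad_fxn(*[Tensor(v) for v in vals])
  np.testing.assert_allclose(tiny_out.numpy(), torch_out.numpy(), rtol=1e-6, atol=1e-6)

check_forward(torch.maximum, Tensor.maximum, [[1., 0., 3., -4.], [-1., -2., 3., 0.]])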
@@ -837,11 +837,13 @@ class Tensor:
     return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan)
   def xor(self, x:Tensor, reverse=False) -> Tensor: return mlops.Xor.apply(*self._broadcasted(x, reverse))
 
-  # TODO: this implicitly changes dtype with /2
-  def maximum(self, x:Union[Tensor, Scalar]) -> Tensor: return (self<x).detach().where(x, (self>x).detach().where(self, (self+x)/2))
+  def maximum(self, x:Union[Tensor, Scalar]) -> Tensor:
+    return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self))
   def minimum(self, x:Union[Tensor, Scalar]) -> Tensor: return -((-self).maximum(-x))
 
   def where(self:Tensor, input_:Union[Tensor, Scalar], other:Union[Tensor, Scalar]):
+    if isinstance(input_, Tensor): input_, other = input_._broadcasted(other)
+    elif isinstance(other, Tensor): other, input_ = other._broadcasted(input_)
     x_,y = self._broadcasted(input_, match_dtype=False)
     x,z = x_._broadcasted(other, match_dtype=False)
     return mlops.Where.apply(x.cast(dtypes.bool), *y._broadcasted(z))
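What the reworked maximum computes, written out as a pure-NumPy sketch (illustration only, not tinygrad code): where self < x take x; where the elements are equal take the 0.5/0.5 average cast back to self.dtype, so ties route gradient through both inputs while the cast keeps int and bool inputs from being upcast by the float multiply; otherwise take self.

import numpy as np

def maximum_like(a, b):
  a, b = np.asarray(a), np.asarray(b)
  tie = (a * 0.5 + b * 0.5).astype(a.dtype)  # average of two equal values is that value; the cast undoes the float upcast
  return np.where(a < b, b, np.where(a == b, tie, a))

print(maximum_like(np.array([1, 0, 3, -4]), 3))            # [3 3 3 3], stays integer
print(maximum_like(np.array([True, False, False]), True))  # [True True True], stays bool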