support specifying dtype for Tensor.linear (#9886)

This commit is contained in:
chenyu
2025-04-14 13:55:11 -04:00
committed by GitHub
parent e8a0aee88d
commit ce454793e6
2 changed files with 15 additions and 3 deletions

View File

@@ -748,9 +748,20 @@ class TestAutoCastType(unittest.TestCase):
def test_matmul(self, dt1, dt2, acc_dt):
t1 = Tensor([0, 1], dtype=dt1)
t2 = Tensor([0, 1], dtype=dt2)
assert (t1 @ t2).dtype == least_upper_dtype(dt1, dt2)
self.assertEqual(t1.matmul(t2).dtype, least_upper_dtype(t1.dtype, t2.dtype))
# if dtype is specified, return in dtype
assert (t1.matmul(t2, dtype=acc_dt).dtype == acc_dt)
self.assertEqual(t1.matmul(t2, dtype=acc_dt).dtype, acc_dt)
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
def test_linear(self, dt1, dt2, dt3, acc_dt):
x = Tensor([0, 1], dtype=dt1)
w = Tensor([0, 1], dtype=dt2)
b = Tensor([0, 1], dtype=dt3)
self.assertEqual(x.linear(w).dtype, least_upper_dtype(x.dtype, w.dtype))
self.assertEqual(x.linear(w, b).dtype, least_upper_dtype(least_upper_dtype(x.dtype, w.dtype), b.dtype))
# if dtype is specified, return in dtype
self.assertEqual(x.linear(w, dtype=acc_dt).dtype, acc_dt)
self.assertEqual(x.linear(w, b, dtype=acc_dt).dtype, acc_dt)
@staticmethod
def check_where_alternate_input_other(input_, other, data_type):