mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
support specifying dtype for Tensor.linear (#9886)
This commit is contained in:
@@ -748,9 +748,20 @@ class TestAutoCastType(unittest.TestCase):
|
||||
def test_matmul(self, dt1, dt2, acc_dt):
|
||||
t1 = Tensor([0, 1], dtype=dt1)
|
||||
t2 = Tensor([0, 1], dtype=dt2)
|
||||
assert (t1 @ t2).dtype == least_upper_dtype(dt1, dt2)
|
||||
self.assertEqual(t1.matmul(t2).dtype, least_upper_dtype(t1.dtype, t2.dtype))
|
||||
# if dtype is specified, return in dtype
|
||||
assert (t1.matmul(t2, dtype=acc_dt).dtype == acc_dt)
|
||||
self.assertEqual(t1.matmul(t2, dtype=acc_dt).dtype, acc_dt)
|
||||
|
||||
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
|
||||
def test_linear(self, dt1, dt2, dt3, acc_dt):
|
||||
x = Tensor([0, 1], dtype=dt1)
|
||||
w = Tensor([0, 1], dtype=dt2)
|
||||
b = Tensor([0, 1], dtype=dt3)
|
||||
self.assertEqual(x.linear(w).dtype, least_upper_dtype(x.dtype, w.dtype))
|
||||
self.assertEqual(x.linear(w, b).dtype, least_upper_dtype(least_upper_dtype(x.dtype, w.dtype), b.dtype))
|
||||
# if dtype is specified, return in dtype
|
||||
self.assertEqual(x.linear(w, dtype=acc_dt).dtype, acc_dt)
|
||||
self.assertEqual(x.linear(w, b, dtype=acc_dt).dtype, acc_dt)
|
||||
|
||||
@staticmethod
|
||||
def check_where_alternate_input_other(input_, other, data_type):
|
||||
|
||||
Reference in New Issue
Block a user