diff --git a/test/test_dtype.py b/test/test_dtype.py
index 79fbfb8f6d..a0c17090e9 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -748,9 +748,20 @@ class TestAutoCastType(unittest.TestCase):
   def test_matmul(self, dt1, dt2, acc_dt):
     t1 = Tensor([0, 1], dtype=dt1)
     t2 = Tensor([0, 1], dtype=dt2)
-    assert (t1 @ t2).dtype == least_upper_dtype(dt1, dt2)
+    self.assertEqual(t1.matmul(t2).dtype, least_upper_dtype(t1.dtype, t2.dtype))
     # if dtype is specified, return in dtype
-    assert (t1.matmul(t2, dtype=acc_dt).dtype == acc_dt)
+    self.assertEqual(t1.matmul(t2, dtype=acc_dt).dtype, acc_dt)
+
+  @given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
+  def test_linear(self, dt1, dt2, dt3, acc_dt):
+    x = Tensor([0, 1], dtype=dt1)
+    w = Tensor([0, 1], dtype=dt2)
+    b = Tensor([0, 1], dtype=dt3)
+    self.assertEqual(x.linear(w).dtype, least_upper_dtype(x.dtype, w.dtype))
+    self.assertEqual(x.linear(w, b).dtype, least_upper_dtype(least_upper_dtype(x.dtype, w.dtype), b.dtype))
+    # if dtype is specified, return in dtype
+    self.assertEqual(x.linear(w, dtype=acc_dt).dtype, acc_dt)
+    self.assertEqual(x.linear(w, b, dtype=acc_dt).dtype, acc_dt)
 
   @staticmethod
   def check_where_alternate_input_other(input_, other, data_type):
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 83dd917677..d39c0f9c2d 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -3665,7 +3665,7 @@ class Tensor(SimpleMathTrait):
 
   # ***** functional nn ops *****
 
-  def linear(self, weight:Tensor, bias:Tensor|None=None) -> Tensor:
+  def linear(self, weight:Tensor, bias:Tensor|None=None, dtype:DTypeLike|None=None) -> Tensor:
     """
     Applies a linear transformation to `self` using `weight` and `bias`.
 
@@ -3678,6 +3678,7 @@ class Tensor(SimpleMathTrait):
       print(t.linear(weight, bias).numpy())
     ```
     """
+    if dtype is not None: return self.cast(dtype).linear(weight.cast(dtype), bias.cast(dtype) if bias is not None else bias)
     x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
     return x.add(bias) if bias is not None else x
 
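
For reference, a minimal usage sketch of the `dtype=` keyword this patch adds to `Tensor.linear`. The tensor values and dtypes below are illustrative, not taken from the patch: when `dtype` is passed, the inputs are cast up front, so the multiply/add runs in, and the result is returned in, that dtype; otherwise the result promotes to the least upper dtype of the inputs, matching the new `test_linear` assertions above.

```python
# Illustrative sketch of the new dtype= keyword on Tensor.linear (values are made up).
from tinygrad import Tensor, dtypes

x = Tensor([[1, 2], [3, 4]], dtype=dtypes.float16)
w = Tensor([[1, 2], [3, 4]], dtype=dtypes.float16)
b = Tensor([1, 2], dtype=dtypes.float16)

# Without dtype=, the result dtype is the least upper dtype of the inputs (float16 here).
print(x.linear(w, b).dtype)                        # dtypes.float16

# With dtype=, x, w, and b are cast first, so the op computes and returns in float32.
print(x.linear(w, b, dtype=dtypes.float32).dtype)  # dtypes.float32
```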