Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-07 03:00:26 -04:00
support specifying dtype for Tensor.linear (#9886)
@@ -748,9 +748,20 @@ class TestAutoCastType(unittest.TestCase):
   def test_matmul(self, dt1, dt2, acc_dt):
     t1 = Tensor([0, 1], dtype=dt1)
     t2 = Tensor([0, 1], dtype=dt2)
-    assert (t1 @ t2).dtype == least_upper_dtype(dt1, dt2)
+    self.assertEqual(t1.matmul(t2).dtype, least_upper_dtype(t1.dtype, t2.dtype))
     # if dtype is specified, return in dtype
-    assert (t1.matmul(t2, dtype=acc_dt).dtype == acc_dt)
+    self.assertEqual(t1.matmul(t2, dtype=acc_dt).dtype, acc_dt)
+
+  @given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
+  def test_linear(self, dt1, dt2, dt3, acc_dt):
+    x = Tensor([0, 1], dtype=dt1)
+    w = Tensor([0, 1], dtype=dt2)
+    b = Tensor([0, 1], dtype=dt3)
+    self.assertEqual(x.linear(w).dtype, least_upper_dtype(x.dtype, w.dtype))
+    self.assertEqual(x.linear(w, b).dtype, least_upper_dtype(least_upper_dtype(x.dtype, w.dtype), b.dtype))
+    # if dtype is specified, return in dtype
+    self.assertEqual(x.linear(w, dtype=acc_dt).dtype, acc_dt)
+    self.assertEqual(x.linear(w, b, dtype=acc_dt).dtype, acc_dt)
 
   @staticmethod
   def check_where_alternate_input_other(input_, other, data_type):
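For context, a quick sketch of the promotion behavior the new `test_linear` pins down, using only tinygrad's public `Tensor`, `dtypes`, and `least_upper_dtype`; the concrete dtypes below are illustrative picks, not taken from the diff:

```python
from tinygrad import Tensor, dtypes
from tinygrad.dtype import least_upper_dtype

x = Tensor([0, 1], dtype=dtypes.int8)
w = Tensor([0, 1], dtype=dtypes.float16)

# default: the output dtype follows the usual promotion rule
assert x.linear(w).dtype == least_upper_dtype(x.dtype, w.dtype)
# explicit dtype: inputs are cast first, so the result comes back in that dtype
assert x.linear(w, dtype=dtypes.float32).dtype == dtypes.float32
```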
@@ -3665,7 +3665,7 @@ class Tensor(SimpleMathTrait):
 
   # ***** functional nn ops *****
 
-  def linear(self, weight:Tensor, bias:Tensor|None=None) -> Tensor:
+  def linear(self, weight:Tensor, bias:Tensor|None=None, dtype:DTypeLike|None=None) -> Tensor:
     """
     Applies a linear transformation to `self` using `weight` and `bias`.
 
@@ -3678,6 +3678,7 @@ class Tensor(SimpleMathTrait):
     print(t.linear(weight, bias).numpy())
     ```
     """
+    if dtype is not None: return self.cast(dtype).linear(weight.cast(dtype), bias.cast(dtype) if bias is not None else bias)
     x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
     return x.add(bias) if bias is not None else x
 
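The one added line casts every input up front and recurses, so the `mul`/`dot` itself runs in the requested dtype rather than being computed in the promoted dtype and cast afterwards, matching the `dtype` parameter that `matmul` already accepts in the test above. A minimal usage sketch; the shapes and dtypes here are illustrative assumptions:

```python
from tinygrad import Tensor, dtypes

x = Tensor([[0., 1.]], dtype=dtypes.float16)    # (1, 2) input
w = Tensor([[1.], [2.]], dtype=dtypes.float16)  # (2, 1) weight
b = Tensor([0.5], dtype=dtypes.float16)         # (1,) bias

# default: float16 all the way through
print(x.linear(w, b).dtype)                        # float16
# accumulate in float32 instead: x, w, b are cast before the dot
print(x.linear(w, b, dtype=dtypes.float32).dtype)  # float32
```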