mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
feat: add acc_dtype to einsum (#4571)
This commit is contained in:
@@ -256,6 +256,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
a, b = Tensor.rand(8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, dtype=tensor_dtype)
|
||||
helper_arg_acc_dtype(a.sum(acc_dtype=acc_dtype), expected_dtype)
|
||||
helper_arg_acc_dtype(a.matmul(b, acc_dtype=acc_dtype), expected_dtype)
|
||||
helper_arg_acc_dtype(Tensor.einsum("ki,ij->kj", a, b, acc_dtype=acc_dtype), expected_dtype)
|
||||
d, w = Tensor.rand(4, 8, 8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, 2, 2, dtype=tensor_dtype)
|
||||
helper_arg_acc_dtype(d.conv2d(w, acc_dtype=acc_dtype), expected_dtype)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user