feat: add acc_dtype to einsum (#4571)

This commit is contained in:
Filip Brzek
2024-05-13 20:02:07 +02:00
committed by GitHub
parent d97d5a7689
commit f7d08bd454
2 changed files with 3 additions and 2 deletions

View File

@@ -256,6 +256,7 @@ class TestLinearizer(unittest.TestCase):
a, b = Tensor.rand(8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, dtype=tensor_dtype)
helper_arg_acc_dtype(a.sum(acc_dtype=acc_dtype), expected_dtype)
helper_arg_acc_dtype(a.matmul(b, acc_dtype=acc_dtype), expected_dtype)
helper_arg_acc_dtype(Tensor.einsum("ki,ij->kj", a, b, acc_dtype=acc_dtype), expected_dtype)
d, w = Tensor.rand(4, 8, 8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, 2, 2, dtype=tensor_dtype)
helper_arg_acc_dtype(d.conv2d(w, acc_dtype=acc_dtype), expected_dtype)