truncate consts early (#6741)

* truncate consts early

* ptx still fails

* Update dtype.py
This commit is contained in:
George Hotz
2024-09-25 16:49:51 +08:00
committed by GitHub
parent e31552e2e0
commit cb22ef379a
6 changed files with 22 additions and 21 deletions

View File

@@ -358,7 +358,7 @@ class TestMultiTensor(unittest.TestCase):
shard_output.backward()
shard_grad = m.conv1.weight.grad.numpy()
# sometimes there is zeros in these grads... why?
np.testing.assert_allclose(grad, shard_grad, atol=3e-6, rtol=3e-6)
np.testing.assert_allclose(grad, shard_grad, atol=1e-5, rtol=1e-5)
def test_multi_tensor_jit_param(self):
@TinyJit