hotfix torch_grad.detach().cpu().numpy() in test_ops (#9268)

Author: chenyu
Date: 2025-02-26 12:27:35 -05:00
Committed by: GitHub
Parent: 49ca90df75
Commit: cd822bbe11

@@ -66,7 +66,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
     tinygrad_fbp = time.monotonic() - st
 
     for i, (t, torch_grad) in enumerate(zip(tiny_grads, torch_grads)):
-      compare(f"backward pass tensor {i}", t.numpy(), torch_grad.detach().numpy(), atol=grad_atol, rtol=grad_rtol)
+      compare(f"backward pass tensor {i}", t.numpy(), torch_grad.detach().cpu().numpy(), atol=grad_atol, rtol=grad_rtol)
 
   if not CI:
     print("\ntesting %40r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms " % \