mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
hotfix torch_grad.detach().cpu().numpy() in test_ops (#9268)
This commit is contained in:
@@ -66,7 +66,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
|
||||
tinygrad_fbp = time.monotonic() - st
|
||||
|
||||
for i, (t, torch_grad) in enumerate(zip(tiny_grads, torch_grads)):
|
||||
compare(f"backward pass tensor {i}", t.numpy(), torch_grad.detach().numpy(), atol=grad_atol, rtol=grad_rtol)
|
||||
compare(f"backward pass tensor {i}", t.numpy(), torch_grad.detach().cpu().numpy(), atol=grad_atol, rtol=grad_rtol)
|
||||
|
||||
if not CI:
|
||||
print("\ntesting %40r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms " % \
|
||||
|
||||
Reference in New Issue
Block a user