mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
separate STRIDED and EXPAND
This commit is contained in:
@@ -31,7 +31,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=1e-6, rtol=1e-3, grad_ato
|
||||
except Exception:
|
||||
raise Exception(f"{s} failed shape {x.shape}")
|
||||
|
||||
compare("forward pass", ret.cpu().data, out.detach().numpy(), atol=atol, rtol=rtol)
|
||||
compare("forward pass", ret.numpy(), out.detach().numpy(), atol=atol, rtol=rtol)
|
||||
|
||||
torch_fbp, tinygrad_fbp = np.nan, np.nan
|
||||
if not forward_only and not FORWARD_ONLY:
|
||||
@@ -45,7 +45,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=1e-6, rtol=1e-3, grad_ato
|
||||
tinygrad_fbp = time.monotonic() - st
|
||||
|
||||
for i, (t, tt) in enumerate(zip(ts, tst)):
|
||||
compare(f"backward pass tensor {i}", tt.cpu().grad.data, t.grad.detach().numpy(), atol=grad_atol, rtol=grad_rtol)
|
||||
compare(f"backward pass tensor {i}", tt.grad.numpy(), t.grad.detach().numpy(), atol=grad_atol, rtol=grad_rtol)
|
||||
|
||||
print("\ntesting %40r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms " % (shps, torch_fp*1000, tinygrad_fp*1000, torch_fbp*1000, tinygrad_fbp*1000), end="")
|
||||
|
||||
@@ -104,6 +104,8 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(2), (2,2)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
def test_matmul(self):
|
||||
helper_test_op([(65), (65,99)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
def test_gemm(self):
|
||||
helper_test_op([(256,256), (256,256)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-3)
|
||||
def test_broadcastdot(self):
|
||||
helper_test_op([(10,45,65), (65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
|
||||
def test_multidot(self):
|
||||
|
||||
Reference in New Issue
Block a user