seperate STRIDED and EXPAND

This commit is contained in:
George Hotz
2022-10-30 13:23:58 -07:00
parent 544cb0a069
commit 2f602a92ff
3 changed files with 9 additions and 4 deletions

View File

@@ -31,7 +31,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=1e-6, rtol=1e-3, grad_ato
except Exception:
raise Exception(f"{s} failed shape {x.shape}")
compare("forward pass", ret.cpu().data, out.detach().numpy(), atol=atol, rtol=rtol)
compare("forward pass", ret.numpy(), out.detach().numpy(), atol=atol, rtol=rtol)
torch_fbp, tinygrad_fbp = np.nan, np.nan
if not forward_only and not FORWARD_ONLY:
@@ -45,7 +45,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=1e-6, rtol=1e-3, grad_ato
tinygrad_fbp = time.monotonic() - st
for i, (t, tt) in enumerate(zip(ts, tst)):
compare(f"backward pass tensor {i}", tt.cpu().grad.data, t.grad.detach().numpy(), atol=grad_atol, rtol=grad_rtol)
compare(f"backward pass tensor {i}", tt.grad.numpy(), t.grad.detach().numpy(), atol=grad_atol, rtol=grad_rtol)
print("\ntesting %40r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms " % (shps, torch_fp*1000, tinygrad_fp*1000, torch_fbp*1000, tinygrad_fbp*1000), end="")
@@ -104,6 +104,8 @@ class TestOps(unittest.TestCase):
helper_test_op([(2), (2,2)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
def test_matmul(self):
helper_test_op([(65), (65,99)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
def test_gemm(self):
helper_test_op([(256,256), (256,256)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-3)
def test_broadcastdot(self):
helper_test_op([(10,45,65), (65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
def test_multidot(self):