Python uop emulator (#3327)

* start uop emu

* tiny_add passes

* more ops

* emulate the whole warp

* test_gemm passes

* metal gemm test pass

* works on big gemm

* works on big gemm

* more tests pass

* touch ups

* fix mypy

* cleanups

* exp2 mypy

* arch is where it belongs

* actually emulate tensor cores

* fix test

* new style
This commit is contained in:
George Hotz
2024-02-08 19:24:55 +01:00
committed by GitHub
parent 3ebf7a3e38
commit c32ea95d7d
7 changed files with 216 additions and 42 deletions

View File

@@ -286,6 +286,9 @@ class TestOps(unittest.TestCase):
helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], True], forward_only=True)
helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], [True, True, False]], forward_only=True)
def test_tiny_add(self):
helper_test_op([(3), (3)], lambda x,y: x+y, Tensor.add, forward_only=True)
def test_add(self):
helper_test_op([(45,68), (45,68)], lambda x,y: x+y, Tensor.add)
helper_test_op([(45,68), (45,68)], lambda x,y: x+y)
@@ -631,6 +634,9 @@ class TestOps(unittest.TestCase):
helper_test_op([(4,3), (1,3,3,5)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
def test_small_gemm(self):
helper_test_op([(8,8), (8,8)], lambda x,y: x.matmul(y), lambda x,y: x@y, atol=1e-3)
def test_small_gemm_range(self):
helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, atol=1e-3, vals=[np.arange(0,64,dtype=np.float32).reshape(8,8),
np.arange(64,128,dtype=np.float32).reshape(8,8)])
def test_small_gemm_eye(self):
helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, atol=1e-3, vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)])
def test_gemm(self):