ops_python: gated load support (#3355)

* start uop emu

* tiny_add passes

* more ops

* emulate the whole warp

* test_gemm passes

* metal gemm test pass

* works on big gemm

* works on big gemm

* more tests pass

* touch ups

* fix mypy

* cleanups

* exp2 mypy

* arch is where it belongs

* actually emulate tensor cores

* fix test

* new style

* add gated load support to PYTHON

* out of bounds error message

* cleaner
This commit is contained in:
George Hotz
2024-02-09 11:16:25 +01:00
committed by GitHub
parent c151131d1b
commit 5f93061f67

View File

@@ -31,6 +31,16 @@ def exec_alu(arg, dtype, p):
if arg == BinaryOps.MOD: return p[0]%p[1]
raise NotImplementedError(f"no support for {arg}")
def _load(m, i):
if i<0 or i>=len(m): raise IndexError(f"access out of bounds, size is {len(m)} and access is {i}")
return m[i]
def load(inp, j=0):
  """Perform an element-wise load, optionally gated.

  Args:
    inp: Either two or four parallel sequences.
      Two entries ``(buffers, indices)``: every element is loaded
      unconditionally.
      Four entries ``(buffers, indices, gates, defaults)``: an element is
      loaded only when its gate is truthy; otherwise its default is used.
    j: Offset added to every index before the access (defaults to 0).

  Returns:
    A list with one value per element of the zipped input sequences.

  Raises:
    IndexError: Propagated from ``_load`` when an executed access is out of
      bounds.
    AssertionError: If ``inp`` has neither two nor four entries.
  """
  if len(inp) == 4:
    # Gated form: when the gate is falsy the buffer is never touched, so an
    # out-of-range index on a disabled element does not raise.
    return [_load(m, x+j) if gate else default for m, x, gate, default in zip(*inp)]
  # Any other arity (e.g. image loads) is not handled by this helper.
  assert len(inp) == 2, "image loads not supported yet"
  return [_load(m, x+j) for m, x in zip(inp[0], inp[1])]
class PythonProgram:
def __init__(self, name:str, lib:bytes):
self.uops: List[Tuple[UOps, Optional[DType], List[int], Any]] = pickle.loads(lib)
@@ -106,9 +116,9 @@ class PythonProgram:
ul[i] = inp[0]
elif uop is UOps.LOAD:
if dtype.sz > 1:
ul[i] = [[m[x+j] for m,x in zip(inp[0], inp[1])] for j in range(dtype.sz)]
ul[i] = [load(inp, j) for j in range(dtype.sz)]
else:
ul[i] = [m[x] for m,x in zip(inp[0], inp[1])]
ul[i] = load(inp)
elif uop is UOps.PHI:
for j in range(len(inp[0])):
inp[0][j] = inp[1][j]