diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index b086fb0f21..bab29d43b2 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -31,6 +31,16 @@ def exec_alu(arg, dtype, p): if arg == BinaryOps.MOD: return p[0]%p[1] raise NotImplementedError(f"no support for {arg}") +def _load(m, i): + if i<0 or i>=len(m): raise IndexError(f"access out of bounds, size is {len(m)} and access is {i}") + return m[i] +def load(inp, j=0): + if len(inp) == 4: + return [_load(m, x+j) if gate else default for m,x,gate,default in zip(*inp)] + else: + assert len(inp) == 2, "image loads not supported yet" + return [_load(m, x+j) for m,x in zip(inp[0], inp[1])] + class PythonProgram: def __init__(self, name:str, lib:bytes): self.uops: List[Tuple[UOps, Optional[DType], List[int], Any]] = pickle.loads(lib) @@ -106,9 +116,9 @@ class PythonProgram: ul[i] = inp[0] elif uop is UOps.LOAD: if dtype.sz > 1: - ul[i] = [[m[x+j] for m,x in zip(inp[0], inp[1])] for j in range(dtype.sz)] + ul[i] = [load(inp, j) for j in range(dtype.sz)] else: - ul[i] = [m[x] for m,x in zip(inp[0], inp[1])] + ul[i] = load(inp) elif uop is UOps.PHI: for j in range(len(inp[0])): inp[0][j] = inp[1][j]