diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py
index 0bc6426805..898294e9f4 100644
--- a/tinygrad/runtime/ops_python.py
+++ b/tinygrad/runtime/ops_python.py
@@ -11,16 +11,15 @@ from tinygrad.ops import BinaryOps, TernaryOps
 from tinygrad.codegen.kernel import LinearizerOptions
 
 def _load(m, i):
-  if i<0 or i>=len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
+  if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
   return m[i]
+
 def load(inp, j=0):
-  if len(inp) == 4:
-    return [_load(m, x+j) if gate else default for m,x,gate,default in zip(*inp)]
-  else:
-    return [_load(m, x+j) for m,x in zip(inp[0], inp[1])]
+  if len(inp) == 4: return [_load(m, x+j) if gate else default for m,x,gate,default in zip(*inp)]
+  else: return [_load(m, x+j) for m,x in zip(inp[0], inp[1])]
 
 def _store(m, i, v):
-  if i<0 or i>=len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
+  if i < 0 or i >= len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
   m[i] = v
 
 class PythonProgram:
@@ -87,18 +86,11 @@ class PythonProgram:
             ul[i] = [x[2-arg[0]] for x in warp]
         elif uop is UOps.CONST:
           casted_arg = int(arg) if dtypes.is_int(dtype) else float(arg)
-          if dtype.count > 1:
-            ul[i] = [[casted_arg] * warp_size for _ in range(dtype.count)]
-          else:
-            ul[i] = [casted_arg] * warp_size
+          ul[i] = [[casted_arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [casted_arg] * warp_size
         elif uop is UOps.DEFINE_ACC:
-          if dtype.count > 1:
-            ul[i] = [[arg] * warp_size for _ in range(dtype.count)]
-          else:
-            ul[i] = [arg] * warp_size
+          ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
         elif uop is UOps.LOOP:
-          if i not in ul:
-            ul[i] = [inp[0][0]] * warp_size
+          if i not in ul: ul[i] = [inp[0][0]] * warp_size
           else:
             for j in range(len(ul[i])):
               ul[i][j] += 1
@@ -107,13 +99,11 @@
               i = loop_ends[i] + 1
               continue
         elif uop in {UOps.CAST, UOps.BITCAST}:
-          if dtype.count > 1:
-            ul[i] = inp
+          if dtype.count > 1: ul[i] = inp
           else:
             assert dtp[0].fmt and dtype.fmt
             pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
-            if uop is UOps.BITCAST:
-              ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
+            if uop is UOps.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
             else:
               casted = [float(x) if dtypes.is_float(dtype) else int(x) if dtypes.is_int(dtype) else x for x in inp[0]]
               overflow_adjust = 2**(dtype.itemsize*8 - 1) if (dtypes.is_int(dtype) and not dtypes.is_unsigned(dtype)) else 0
@@ -134,8 +124,7 @@ class PythonProgram:
           else:
             ul[i] = load(inp)
         elif uop is UOps.PHI:
-          for j in range(len(inp[0])):
-            inp[0][j] = inp[1][j]
+          for j in range(len(inp[0])): inp[0][j] = inp[1][j]
           ul[i] = inp[0]
         elif uop is UOps.GEP:
           ul[i] = inp[0][arg]
@@ -176,8 +165,7 @@ class PythonProgram:
             def b_elem(x, i, j, goff): return x[(j%2)+(j//8)*2][goff+(j//2)%4+(i)*4]  # B (4 elements on 32 threads)
             def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8)  # (i, j), C, D (4 elements on 32 threads)
             ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
-          else:
-            raise NotImplementedError(f"unimplemented tensor core {arg}")
+          else: raise NotImplementedError(f"unimplemented tensor core {arg}")
         elif uop is UOps.ALU:
           assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {arg}"
           assert all_same([dtype] + dtp) or arg in {BinaryOps.CMPEQ, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {arg}"