cleanup ops python (#3908)

i just want to merge lars!
This commit is contained in:
chenyu
2024-03-24 11:36:31 -04:00
committed by GitHub
parent 2c69888654
commit 8c8b57fd5f

View File

@@ -11,16 +11,15 @@ from tinygrad.ops import BinaryOps, TernaryOps
from tinygrad.codegen.kernel import LinearizerOptions
def _load(m, i):
if i<0 or i>=len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
return m[i]
def load(inp, j=0):
  """Gather one element per lane from backing memories at offset j.

  inp is either (mems, idxs, gates, defaults) — a gated load where a falsy
  gate yields the lane's default instead of touching memory — or just
  (mems, idxs) for an unconditional load.
  """
  if len(inp) == 4: return [_load(m, x+j) if gate else default for m,x,gate,default in zip(*inp)]
  # no `else` needed after a return — flat guard style
  return [_load(m, x+j) for m,x in zip(inp[0], inp[1])]
def _store(m, i, v):
if i<0 or i>=len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
if i < 0 or i >= len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
m[i] = v
class PythonProgram:
@@ -87,18 +86,11 @@ class PythonProgram:
ul[i] = [x[2-arg[0]] for x in warp]
elif uop is UOps.CONST:
casted_arg = int(arg) if dtypes.is_int(dtype) else float(arg)
if dtype.count > 1:
ul[i] = [[casted_arg] * warp_size for _ in range(dtype.count)]
else:
ul[i] = [casted_arg] * warp_size
ul[i] = [[casted_arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [casted_arg] * warp_size
elif uop is UOps.DEFINE_ACC:
if dtype.count > 1:
ul[i] = [[arg] * warp_size for _ in range(dtype.count)]
else:
ul[i] = [arg] * warp_size
ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
elif uop is UOps.LOOP:
if i not in ul:
ul[i] = [inp[0][0]] * warp_size
if i not in ul: ul[i] = [inp[0][0]] * warp_size
else:
for j in range(len(ul[i])):
ul[i][j] += 1
@@ -107,13 +99,11 @@ class PythonProgram:
i = loop_ends[i] + 1
continue
elif uop in {UOps.CAST, UOps.BITCAST}:
if dtype.count > 1:
ul[i] = inp
if dtype.count > 1: ul[i] = inp
else:
assert dtp[0].fmt and dtype.fmt
pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
if uop is UOps.BITCAST:
ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
if uop is UOps.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
else:
casted = [float(x) if dtypes.is_float(dtype) else int(x) if dtypes.is_int(dtype) else x for x in inp[0]]
overflow_adjust = 2**(dtype.itemsize*8 - 1) if (dtypes.is_int(dtype) and not dtypes.is_unsigned(dtype)) else 0
@@ -134,8 +124,7 @@ class PythonProgram:
else:
ul[i] = load(inp)
elif uop is UOps.PHI:
for j in range(len(inp[0])):
inp[0][j] = inp[1][j]
for j in range(len(inp[0])): inp[0][j] = inp[1][j]
ul[i] = inp[0]
elif uop is UOps.GEP:
ul[i] = inp[0][arg]
@@ -176,8 +165,7 @@ class PythonProgram:
def b_elem(x, i, j, goff): return x[(j%2)+(j//8)*2][goff+(j//2)%4+(i)*4] # B (4 elements on 32 threads)
def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8) # (i, j), C, D (4 elements on 32 threads)
ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
else:
raise NotImplementedError(f"unimplemented tensor core {arg}")
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
elif uop is UOps.ALU:
assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {arg}"
assert all_same([dtype] + dtp) or arg in {BinaryOps.CMPEQ, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {arg}"