mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-15 00:55:11 -05:00
@@ -11,16 +11,15 @@ from tinygrad.ops import BinaryOps, TernaryOps
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
|
||||
def _load(m, i):
|
||||
if i<0 or i>=len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
|
||||
if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
|
||||
return m[i]
|
||||
|
||||
def load(inp, j=0):
  """Perform one bounds-checked load per entry of the input lists.

  Args:
    inp: sequence of parallel lists. ``inp[0]`` holds the buffers and ``inp[1]``
      the element offsets. When ``inp`` has exactly four entries, ``inp[2]`` holds
      per-entry gates and ``inp[3]`` the values to use where the gate is falsy.
    j: constant added to every offset before loading.

  Returns:
    A list with one loaded (or default) value per zipped entry.

  Raises:
    IndexError: propagated from ``_load`` when an ungated access is out of bounds.
  """
  # Gated form: a falsy gate skips the memory access entirely, so an out-of-range
  # offset on a disabled entry does not raise.
  if len(inp) == 4: return [_load(m, x+j) if gate else default for m,x,gate,default in zip(*inp)]
  else: return [_load(m, x+j) for m,x in zip(inp[0], inp[1])]
|
||||
|
||||
def _store(m, i, v):
|
||||
if i<0 or i>=len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
|
||||
if i < 0 or i >= len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
|
||||
m[i] = v
|
||||
|
||||
class PythonProgram:
|
||||
@@ -87,18 +86,11 @@ class PythonProgram:
|
||||
ul[i] = [x[2-arg[0]] for x in warp]
|
||||
elif uop is UOps.CONST:
|
||||
casted_arg = int(arg) if dtypes.is_int(dtype) else float(arg)
|
||||
if dtype.count > 1:
|
||||
ul[i] = [[casted_arg] * warp_size for _ in range(dtype.count)]
|
||||
else:
|
||||
ul[i] = [casted_arg] * warp_size
|
||||
ul[i] = [[casted_arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [casted_arg] * warp_size
|
||||
elif uop is UOps.DEFINE_ACC:
|
||||
if dtype.count > 1:
|
||||
ul[i] = [[arg] * warp_size for _ in range(dtype.count)]
|
||||
else:
|
||||
ul[i] = [arg] * warp_size
|
||||
ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
|
||||
elif uop is UOps.LOOP:
|
||||
if i not in ul:
|
||||
ul[i] = [inp[0][0]] * warp_size
|
||||
if i not in ul: ul[i] = [inp[0][0]] * warp_size
|
||||
else:
|
||||
for j in range(len(ul[i])):
|
||||
ul[i][j] += 1
|
||||
@@ -107,13 +99,11 @@ class PythonProgram:
|
||||
i = loop_ends[i] + 1
|
||||
continue
|
||||
elif uop in {UOps.CAST, UOps.BITCAST}:
|
||||
if dtype.count > 1:
|
||||
ul[i] = inp
|
||||
if dtype.count > 1: ul[i] = inp
|
||||
else:
|
||||
assert dtp[0].fmt and dtype.fmt
|
||||
pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
|
||||
if uop is UOps.BITCAST:
|
||||
ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
|
||||
if uop is UOps.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
|
||||
else:
|
||||
casted = [float(x) if dtypes.is_float(dtype) else int(x) if dtypes.is_int(dtype) else x for x in inp[0]]
|
||||
overflow_adjust = 2**(dtype.itemsize*8 - 1) if (dtypes.is_int(dtype) and not dtypes.is_unsigned(dtype)) else 0
|
||||
@@ -134,8 +124,7 @@ class PythonProgram:
|
||||
else:
|
||||
ul[i] = load(inp)
|
||||
elif uop is UOps.PHI:
|
||||
for j in range(len(inp[0])):
|
||||
inp[0][j] = inp[1][j]
|
||||
for j in range(len(inp[0])): inp[0][j] = inp[1][j]
|
||||
ul[i] = inp[0]
|
||||
elif uop is UOps.GEP:
|
||||
ul[i] = inp[0][arg]
|
||||
@@ -176,8 +165,7 @@ class PythonProgram:
|
||||
def b_elem(x, i, j, goff): return x[(j%2)+(j//8)*2][goff+(j//2)%4+(i)*4] # B (4 elements on 32 threads)
|
||||
def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8) # (i, j), C, D (4 elements on 32 threads)
|
||||
ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
|
||||
else:
|
||||
raise NotImplementedError(f"unimplemented tensor core {arg}")
|
||||
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
|
||||
elif uop is UOps.ALU:
|
||||
assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {arg}"
|
||||
assert all_same([dtype] + dtp) or arg in {BinaryOps.CMPEQ, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {arg}"
|
||||
|
||||
Reference in New Issue
Block a user