mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
zero len ranges fail (#12974)
* zero len ranges fail
* fix Python backend
* fix llvm
* fix ptx
* yolo fix nir
* this works...
* always store...
* always store...
* Revert "always store..."
This reverts commit 0816cf344d.
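The bug in one example: a RANGE whose bound resolves to 0 must execute its loop body zero times, but several backends emitted do-while style loops that always ran the body once. A minimal reproduction in the spirit of the TestZeroRange test added below (the bound variable and shapes are taken from that test; at i == 0 the reduce loop has zero iterations):

    from tinygrad import Tensor, dtypes
    from tinygrad.uop.ops import UOp

    for i in range(3, -1, -1):
      v = UOp.variable("i", 0, 5).bind(i)  # runtime-bound loop extent
      out = Tensor.ones(10, dtype=dtypes.int).contiguous().shrink(((0, v),)).sum()
      assert out.item() == i  # i == 0 is the zero-length case this PR makes pass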
.github/workflows/test.yml
@@ -309,7 +309,7 @@ jobs:
         key: spec-unit
         deps: testing_unit
     - name: Test SPEC=2
-      run: IGNORE_OOB=0 SPEC=2 PYTHONPATH="." pytest --maxfail=10 -n auto --durations=30 --ignore=test/models --ignore test/unit/test_hashing.py --timeout 40 -k "not test_setitem_big" --splits 2 --group ${{ matrix.group }}
+      run: IGNORE_OOB=0 SPEC=2 PYTHONPATH="." pytest --maxfail=10 -n auto --durations=30 --ignore=test/models --ignore test/unit/test_hashing.py --timeout 60 -k "not test_setitem_big" --splits 2 --group ${{ matrix.group }}

   fuzzing:
     name: Fuzzing
@@ -559,5 +559,12 @@ class TestUOpRender(unittest.TestCase):
     u = UOp(Ops.VECTORIZE, dtype=dtypes.int.vec(3), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)))
     self.assertEqual(u.render(), "(0, 1, 2)")

+class TestZeroRange(unittest.TestCase):
+  def test_reduce_variable(self):
+    for i in range(3,-1,-1):
+      v = UOp.variable("i", 0, 5).bind(i)
+      out = Tensor.ones(10, dtype=dtypes.int).contiguous().shrink(((0,v),)).sum()
+      self.assertEqual(out.item(), i)
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)
tinygrad/renderer/llvmir.py

@@ -107,14 +107,20 @@ base_rewrite = PatternMatcher([

   # range
   (UPat(Ops.RANGE, name="r"), lambda ctx,r:
-    f" br label %loop_entry_{range_str(r)}\nloop_entry_{range_str(r)}:\n"
-    f" br label %loop_body_{range_str(r)}\nloop_body_{range_str(r)}:\n"
-    f" {ctx[r]} = phi {ldt(r.dtype)} [ 0, %loop_entry_{range_str(r)} ], [ {ctx[r]}phi, %loop_latch_{range_str(r)} ]"),
-  (UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE, name="r")), name="x"), lambda ctx,x,r:
-    f" br label %loop_latch_{range_str(r)}\nloop_latch_{range_str(r)}:\n"
+    f" br label %loop_entry_{range_str(r)}\n"
+    f"loop_entry_{range_str(r)}:\n"
+    f" br label %loop_latch_{range_str(r)}\n"
+    f"loop_latch_{range_str(r)}:\n"
+    f" {ctx[r]} = phi {ldt(r.dtype)} [ 0, %loop_entry_{range_str(r)} ], [ {ctx[r]}phi, %loop_footer_{range_str(r)} ]\n"
     f" {ctx[r]}phi = add {ldt(r.dtype)} {ctx[r]}, 1\n"
-    f" {ctx[x]} = icmp ult {ldt(r.dtype)} {ctx[r]}phi, {ctx[r.src[0]]}\n"
-    f" br i1 {ctx[x]}, label %loop_body_{range_str(r)}, label %loop_exit_{range_str(r)}\nloop_exit_{range_str(r)}:"),
+    f" {ctx[r]}cmp = icmp ult {ldt(r.dtype)} {ctx[r]}, {ctx[r.src[0]]}\n"
+    f" br i1 {ctx[r]}cmp, label %loop_body_{range_str(r)}, label %loop_exit_{range_str(r)}\n"
+    f"loop_body_{range_str(r)}:"),
+  (UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE, name="r"))), lambda r:
+    f" br label %loop_footer_{range_str(r)}\n"
+    f"loop_footer_{range_str(r)}:\n"
+    f" br label %loop_latch_{range_str(r)}\n"
+    f"loop_exit_{range_str(r)}:"),

   # if
   (UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),
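Why the LLVM restructuring fixes the zero-length case: the old loop branched from loop_entry straight into loop_body and only tested the counter at the bottom, so a bound of 0 still ran the body once. The new layout tests in loop_latch before falling into loop_body. A sketch of the emitted control flow in Python terms (block names refer to the labels above):

    def llvm_range(bound, body):
      i = 0              # loop_latch phi: 0 when entered from loop_entry
      while i < bound:   # icmp ult + conditional br; false goes straight to loop_exit
        body(i)          # loop_body, terminated by END
        i += 1           # {r}phi, fed back to the phi through loop_footer
      # with bound == 0 the body never executes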
tinygrad/renderer/nir.py

@@ -3,7 +3,7 @@ from tinygrad.dtype import AddrSpace, DType, PtrDType, dtypes
 from tinygrad.helpers import DEBUG, OSX, unwrap
 from tinygrad.renderer import Renderer
 from tinygrad.renderer.cstyle import CUDARenderer
-from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat
+from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str
 import tinygrad.runtime.autogen.mesa as mesa
 import base64, ctypes, ctypes.util, struct, functools, inspect
@@ -182,14 +182,17 @@ class NIRRenderer(Renderer):
       self.r[u] = nimm(self.b, self.b.shader.contents.info.shared_size, dtypes.long)
       self.b.shader.contents.info.shared_size += u.dtype.nbytes()
     elif u.op == Ops.RANGE:
-      ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{u.arg[0]}".encode()).contents))
+      ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{range_str(u)}".encode()).contents))
       nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype)
       mesa.nir_push_loop(self.b)
       self.r[u] = nload(self.b, AddrSpace.REG, i, u.dtype)
+      nif(self.b, nalu(self.b, "ilt", self.r[u], self.r[u.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break))
     elif u.op == Ops.END:
       r = u.src[1]
-      nif(self.b, nalu(self.b, "ilt", x:=nalu(self.b, "iadd", self.r[r], nimm(self.b, 1, r.dtype)), self.r[r.src[0]]),
-          functools.partial(nstore, self.b, AddrSpace.REG, ranges.pop(), x, r.dtype), lambda: njump(self.b, mesa.nir_jump_break))
+      next_i = nalu(self.b, "iadd", self.r[r], nimm(self.b, 1, r.dtype))
+      # TODO: this nif should be removable ... but TestMultiTensor.test_double_matmul_shard_W_0 segfaults with it gone
+      nif(self.b, nalu(self.b, "ilt", next_i, self.r[r.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break))
+      nstore(self.b, AddrSpace.REG, ranges.pop(), next_i, r.dtype)
       mesa.nir_pop_loop(self.b, None)
     else:
       if (d:=self.def_rewrite.rewrite(u, ctx=self)) is None: raise RuntimeError(f"failed to render {u.op} srcs {[x.dtype for x in u.src]}")
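The NIR fix is the same idea in structured control flow: the counter check moves to the top of the loop (the nif inserted right after the nload), so a zero bound breaks out before the body runs. A Python sketch of the loop the builder now emits; nstore/nload/nalu/nif correspond to the calls above:

    def nir_range(bound, body):
      i_var = 0                         # nstore(..., nimm(0), ...)
      while True:                       # nir_push_loop
        i = i_var                       # nload
        if not (i < bound): break       # new top-of-loop nif: zero-length exits here
        body(i)
        next_i = i + 1                  # nalu "iadd"
        if not (next_i < bound): break  # the nif the TODO says should be removable
        i_var = next_i                  # nstore before the back edge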
tinygrad/renderer/ptx.py

@@ -119,8 +119,12 @@ string_rewrite = PatternMatcher([
     if x.dtype.count > 1 else f"ld.{mem_type(buf)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
   # simple
   (UPat(Ops.DEFINE_REG, src=()), lambda ctx: []),
-  (UPat(Ops.RANGE, name="r"), lambda ctx, r: [f"mov.u32 {ctx.r[r]}, 0;", "LOOP_" + f"{ctx.r[r][1:]}:"]),
+  (UPat(Ops.RANGE, name="r"), lambda ctx, r: [
+    f"mov.u32 {ctx.r[r]}, -1;",
+    f"bra END_{ctx.r[r][1:]};",
+    "LOOP_" + f"{ctx.r[r][1:]}:"]),
   (UPat(Ops.END, name="x", src=(UPat(), UPat(Ops.RANGE, name="r"))), lambda ctx, x, r: [
+    "END_" + f"{ctx.r[r][1:]}:",
     ctx.code_for_op[Ops.ADD](ctx.r[r], ctx.r[r], "1", dtypes.int, ctx.types[dtypes.int]),
     ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[r], ctx.r[r.src[0]], dtypes.int, ctx.types[dtypes.int]),
     f"@{ctx.r[x]} bra LOOP_{ctx.r[r][1:]};"]),
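PTX keeps a bottom-test loop but now enters it through the END_ block: the counter starts at -1 and the first instructions executed are the increment and compare, so a zero bound falls through without ever reaching LOOP_. The same control flow in Python terms:

    def ptx_range(bound, body):
      r = -1                 # mov.u32 r, -1
      while True:            # bra END_...
        r += 1               # END_: add r, r, 1
        if not (r < bound):  # setp from code_for_op[Ops.CMPLT]
          break              # predicated bra not taken: loop exit
        body(r)              # LOOP_: body sees r = 0 .. bound-1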
tinygrad/runtime/ops_python.py

@@ -52,41 +52,38 @@ def generic_wmma_helper(inp, warp_size, WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_

 class PythonProgram:
   def __init__(self, name:str, lib:bytes):
-    self.uops: list[tuple[Ops, DType|None, list[int], Any]] = pickle.loads(lib)
+    self.uops: list[tuple[Ops, DType, list[int], Any]] = pickle.loads(lib)
   def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
     st = time.perf_counter()
     warp = list(itertools.product(*[range(x) for x in local_size[::-1]]))
     warp_size = len(warp)
+    void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP, Ops.STORE}
+    loop_ends: dict[int, int] = {srcs[1]:i for i, (uop, _, srcs, _) in enumerate(self.uops) if uop == Ops.END}
     for idxs in itertools.product(*[range(x) for x in global_size[::-1]]):
-      ul: dict[int, Any] = {}
-      dl: dict[int, DType] = {}
+      values: dict[int, Any] = {}
       pbufs: list[memoryview] = list(bufs)
       pvals: list[int] = list(vals)
       i = 0
-      loop_ends: dict[int, int] = {}
       while i < len(self.uops):
-        uop, dtype, idp, arg = self.uops[i]
-        void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP, Ops.STORE}
-        inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
-        dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
-        if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
+        uop, dtype, srcs, arg = self.uops[i]
+        src_values = [values[v] for v in srcs if self.uops[v][0] not in void_ops]
+        src_dtypes = [self.uops[v][1] for v in srcs if self.uops[v][0] not in void_ops]
+        if getenv("TRACE"): print(i, uop, dtype, arg, src_values, src_dtypes)
         if uop is Ops.END:
-          loop_ends[idp[1]] = i
-          i = idp[1]
+          i = srcs[1]
           continue
         if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP):
           # in the python emulator, the warp is always in sync
           i += 1
           continue
-        assert dtype is not None, f"{uop} is missing a dtype"
-        dl[i] = dtype
         if uop is Ops.STORE:
-          for j,val in enumerate(inp[1] if dtp[1].count > 1 else [inp[1]]):
-            for (m,o,g),v in zip(inp[0], val):
-              if g: _store(m, o+j, v, dtp[1].scalar())
+          for j,val in enumerate(src_values[1] if src_dtypes[1].count > 1 else [src_values[1]]):
+            for (m,o,g),v in zip(src_values[0], val):
+              if g: _store(m, o+j, v, src_dtypes[1].scalar())
           i += 1
           continue
-        if uop is Ops.AFTER: ul[i] = inp[0]
+        if uop is Ops.AFTER: values[i] = src_values[0]
         elif uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
           assert isinstance(dtype, PtrDType), dtype
           storage_fmt = storage_fmt_for_dtype(dtype.base.scalar())
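Because END's second source is the index of its matching RANGE, the emulator can now build loop_ends in one pass before execution (the dict comprehension above) instead of discovering each END on first fall-through, and that is what lets a zero-iteration RANGE jump past an END it has never visited. A self-contained sketch with a hypothetical uop stream:

    # hypothetical (op, dtype, src_indices, arg) tuples
    uops = [
      ("CONST", "int", [], 0),        # 0: the bound (0 -> zero-length loop)
      ("RANGE", "int", [0], 0),       # 1: counter, bound comes from uop 0
      ("NOOP",  None,  [], None),     # 2: stands in for the loop body
      ("END",   None,  [2, 1], None), # 3: src[1] == 1 points back at its RANGE
    ]
    loop_ends = {srcs[1]: i for i, (op, _, srcs, _) in enumerate(uops) if op == "END"}
    assert loop_ends == {1: 3}  # the RANGE at 1 can exit directly to index 4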
@@ -94,72 +91,73 @@ class PythonProgram:
           if TYPE_CHECKING or sys.version_info < (3, 12): assert storage_fmt != "e"
           if uop is Ops.DEFINE_REG:
             # REGs are per thread
-            ul[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)]
+            values[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)]
           else:
             buf = memoryview(bytearray(dtype.size*dtype.itemsize)) if uop is not Ops.DEFINE_GLOBAL else pbufs.pop(0)
-            ul[i] = [buf.cast(storage_fmt)] * warp_size
+            values[i] = [buf.cast(storage_fmt)] * warp_size
         elif uop is Ops.DEFINE_VAR:
-          ul[i] = [pvals.pop(0)] * warp_size
+          values[i] = [pvals.pop(0)] * warp_size
         elif uop is Ops.SPECIAL:
-          if arg[0] == 'g': ul[i] = [idxs[2-int(arg[-1])]] * warp_size
-          elif arg[0] == 'l': ul[i] = [x[2-int(arg[-1])] for x in warp]
-        elif uop is Ops.CONST: ul[i] = [arg] * warp_size
+          if arg[0] == 'g': values[i] = [idxs[2-int(arg[-1])]] * warp_size
+          elif arg[0] == 'l': values[i] = [x[2-int(arg[-1])] for x in warp]
+        elif uop is Ops.CONST: values[i] = [arg] * warp_size
         elif uop is Ops.INDEX:
           ret:list = []
-          if isinstance(dtp[0], ImageDType):
-            for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]):
-              if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append((m, None))
-              else: ret.append((m, ox*4 + oy*dtp[0].shape[1]*4))
+          if isinstance(src_dtypes[0], ImageDType):
+            for m,ox,oy in zip(src_values[0], src_values[1][0], src_values[1][1]):
+              if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
+              else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
           else:
-            for m,o in zip(inp[0], inp[1]): ret.append((m,o))
-          ul[i] = [(m,o,g) for (m,o),g in zip(ret, inp[2] if len(inp) == 3 else [True]*len(ret))] # set the gate last
+            for m,o in zip(src_values[0], src_values[1]): ret.append((m,o))
+          values[i] = [(m,o,g) for (m,o),g in zip(ret, src_values[2] if len(src_values) == 3 else [True]*len(ret))] # set the gate last
         elif uop is Ops.CAST and isinstance(dtype, PtrDType):
-          ul[i] = inp[0]
+          values[i] = src_values[0]
         elif uop is Ops.RANGE:
-          if i not in ul: ul[i] = [0] * warp_size
+          if i not in values: values[i] = [0] * warp_size
           else:
-            for j in range(len(ul[i])):
-              ul[i][j] += 1
-            if ul[i][0] == inp[0][0]:
-              del ul[i]
-              i = loop_ends[i] + 1
-              continue
-        elif uop is Ops.VECTORIZE: ul[i] = inp
+            for j in range(len(values[i])):
+              values[i][j] += 1
+          if values[i][0] == src_values[0][0]:
+            del values[i]
+            i = loop_ends[i] + 1
+            continue
+        elif uop is Ops.VECTORIZE: values[i] = src_values
         elif uop is Ops.BITCAST:
-          packed = struct.pack(str(warp_size) + storage_fmt_for_dtype(dtp[0].scalar()), *[to_storage_scalar(x, dtp[0].scalar()) for x in inp[0]])
-          ul[i] = list(struct.unpack(str(warp_size) + storage_fmt_for_dtype(dtype.scalar()), packed))
-          ul[i] = [from_storage_scalar(x, dtype.scalar()) for x in ul[i]]
+          packed = struct.pack(str(warp_size) + storage_fmt_for_dtype(src_dtypes[0].scalar()),
+                               *[to_storage_scalar(x, src_dtypes[0].scalar()) for x in src_values[0]])
+          values[i] = list(struct.unpack(str(warp_size) + storage_fmt_for_dtype(dtype.scalar()), packed))
+          values[i] = [from_storage_scalar(x, dtype.scalar()) for x in values[i]]
         elif uop is Ops.CAST:
-          ul[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in inp[0]]
+          values[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in src_values[0]]
         elif uop is Ops.LOAD:
           if dtype.count > 1:
-            ul[i] = [load([inp[i][j] if i != 0 and dtp[i].count > 1 else inp[i] for i in range(len(inp))], j, dtype.scalar()) \
-              for j in range(dtype.count)]
+            values[i] = [load([src_values[i][j] if i != 0 and src_dtypes[i].count > 1 else src_values[i] \
+              for i in range(len(src_values))], j, dtype.scalar()) for j in range(dtype.count)]
           else:
-            ul[i] = load(inp, 0, dtype)
-        elif uop is Ops.GEP: ul[i] = inp[0][get_single_element(arg)]
+            values[i] = load(src_values, 0, dtype)
+        elif uop is Ops.GEP: values[i] = src_values[0][get_single_element(arg)]
         elif uop is Ops.WMMA:
-          first_src_dtype = self.uops[idp[0]][1]
+          first_src_dtype = self.uops[srcs[0]][1]
           assert isinstance(first_src_dtype, DType) # mypy
           dims, dtype_in, device, threads = arg[1], first_src_dtype.scalar(), arg[4], arg[5]
-          wmma_helper = functools.partial(generic_wmma_helper, inp, warp_size)
+          wmma_helper = functools.partial(generic_wmma_helper, src_values, warp_size)
           # TODO: refactor these to a shared TensorCoreLayout in kernel.py
           if device == "METAL":
             # A (2 elements on 32 threads): row major
             def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
             # (i, j), C, D (2 elements on 32 threads): row major same as A/B
             def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
-            ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
+            values[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
           elif device == "AMD" and threads == 64:
             def a_elem(x, k, row, goff): return x[k%(dims[2]//4)][goff + (k//(dims[2]//4))*16 + row]
             def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
             def c_map(lane, elem): return (lane%16, (lane//16)*4 + elem)
-            ul[i] = wmma_helper(64, dims[2], len(inp[0]), len(inp[1]), len(inp[2]), a_elem, b_elem, c_map)
-          elif device == "AMD" and len(inp[0]) == 8: # RDNA4
+            values[i] = wmma_helper(64, dims[2], len(src_values[0]), len(src_values[1]), len(src_values[2]), a_elem, b_elem, c_map)
+          elif device == "AMD" and len(src_values[0]) == 8: # RDNA4
             def a_elem(x, k, row, goff): return x[k - [0, 4, 4, 8][k//4]][goff + row + [0, 16, 0, 16][k//4]]
             def b_elem(x, col, k, goff): return a_elem(x, k, col, goff)
             def c_map(lane, elem): return (lane%16, (lane//16)*8 + elem)
-            ul[i] = wmma_helper(32, 16, 8, 8, 8, a_elem, b_elem, c_map)
+            values[i] = wmma_helper(32, 16, 8, 8, 8, a_elem, b_elem, c_map)
           elif device == "AMD":
             # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
             def a_elem(x, k, row, goff):
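The other half of the emulator fix is the Ops.RANGE case above: the bound check now runs on the first visit too, not only after an increment, so a bound of 0 deletes the counter and jumps past the loop before the body ever executes. A scalar sketch of that step (warp dimension dropped for clarity):

    def step_range(i, bound, values, loop_ends):
      # returns the next uop index after executing the RANGE at index i
      if i not in values: values[i] = 0  # first entry: counter starts at 0
      else: values[i] += 1               # revisited via END: increment
      if values[i] == bound:             # checked even on a fresh counter,
        del values[i]                    # so bound == 0 skips the body entirely
        return loop_ends[i] + 1
      return i + 1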
@@ -168,7 +166,7 @@ class PythonProgram:
             # B (16 elements on 32 threads): row major, lane 16-32 == lane 0-15
             def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
             def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
-            ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
+            values[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
           elif device == "CUDA":
             # (col, row) given (lane, elem) for C & D (4 elements on 32 threads); shared by all tc shapes with M=16 N=8
             def c_map(lane, elem): return (elem%2 + (lane%4)*2, lane//4 + (elem//2)*8)
@@ -176,22 +174,22 @@ class PythonProgram:
             if dims == (8,16,16):
               def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2 + (k//8)*4][goff + (k//2)%4 + (row%8)*4]
               def b_elem(x, col, k, goff): return x[k%2 + (k//8)*2][goff + (k//2)%4 + col*4]
-              ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
+              values[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)

             elif dims == (8,16,32):
               def a_elem(x, k, row, goff): return x[k%4 + (row//8)*4 + (k//16)*8][goff + (k//4)%4 + (row%8)*4]
               def b_elem(x, col, k, goff): return x[k%4 + (k//16)*4][goff + (k//4)%4 + col*4]
-              ul[i] = wmma_helper(32, 32, 16, 8, 4, a_elem, b_elem, c_map)
+              values[i] = wmma_helper(32, 32, 16, 8, 4, a_elem, b_elem, c_map)

             elif dims == (8,16,8) and dtype_in == dtypes.half:
               def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2][goff + k//2 + (row%8)*4]
               def b_elem(x, col, k, goff): return x[k%2][goff + k//2 + col*4]
-              ul[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
+              values[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)

             elif dims == (8,16,8) and dtype_in == dtypes.float:
               def a_elem(x, k, row, goff): return x[(k//4)*2 + row//8][goff + k%4 + (row%8)*4]
               def b_elem(x, col, k, goff): return x[k//4][goff + k%4 + col*4]
-              ul[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
+              values[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)

             else: raise NotImplementedError(f"unimplemented tensor core {arg}")
           elif device == "INTEL":
@@ -201,17 +199,17 @@ class PythonProgram:
             def b_elem(x, col, k, goff): return x[k][goff+col]
             # C, D (8 elements on 8 threads)
             def c_map(lane, elem): return (lane, elem)
-            ul[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
+            values[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
           elif device == "CPU":
             def elem(x, col, row, _): return x[col+row][0] # k is always 0
             def c_map(lane, elem): return (elem%16, elem//16)
-            ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
+            values[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
           else: raise NotImplementedError(f"unimplemented tensor core {arg}")
         elif uop in GroupOp.ALU:
-          assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {uop}"
-          assert all_same([dtype] + dtp) or uop in {*GroupOp.Comparison, Ops.WHERE}, f"dtype mismatch on {uop}"
-          ul[i] = [exec_alu(uop, dtype, p) for p in zip(*inp)]
-        assert i in ul, (uop, dtype, idp, arg)
+          assert all_same([len(x) for x in src_values]), f"{[len(x) for x in src_values]} doesn't match on {uop}"
+          assert all_same([dtype] + src_dtypes) or uop in {*GroupOp.Comparison, Ops.WHERE}, f"dtype mismatch on {uop}"
+          values[i] = [exec_alu(uop, dtype, p) for p in zip(*src_values)]
+        assert i in values, (uop, dtype, srcs, arg)
         i += 1
     return time.perf_counter() - st