zero len ranges fail (#12974)

* zero len ranges fail

* fix Python backend

* fix llvm

* fix ptx

* yolo fix nir

* this works...

* always store...

* always store...

* Revert "always store..."

This reverts commit 0816cf344d.
This commit is contained in:
George Hotz
2025-10-28 22:49:55 +08:00
committed by GitHub
parent e936aa7974
commit 5e01cc299b
6 changed files with 93 additions and 75 deletions

View File

@@ -309,7 +309,7 @@ jobs:
key: spec-unit key: spec-unit
deps: testing_unit deps: testing_unit
- name: Test SPEC=2 - name: Test SPEC=2
run: IGNORE_OOB=0 SPEC=2 PYTHONPATH="." pytest --maxfail=10 -n auto --durations=30 --ignore=test/models --ignore test/unit/test_hashing.py --timeout 40 -k "not test_setitem_big" --splits 2 --group ${{ matrix.group }} run: IGNORE_OOB=0 SPEC=2 PYTHONPATH="." pytest --maxfail=10 -n auto --durations=30 --ignore=test/models --ignore test/unit/test_hashing.py --timeout 60 -k "not test_setitem_big" --splits 2 --group ${{ matrix.group }}
fuzzing: fuzzing:
name: Fuzzing name: Fuzzing

View File

@@ -559,5 +559,12 @@ class TestUOpRender(unittest.TestCase):
u = UOp(Ops.VECTORIZE, dtype=dtypes.int.vec(3), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2))) u = UOp(Ops.VECTORIZE, dtype=dtypes.int.vec(3), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)))
self.assertEqual(u.render(), "(0, 1, 2)") self.assertEqual(u.render(), "(0, 1, 2)")
class TestZeroRange(unittest.TestCase):
def test_reduce_variable(self):
for i in range(3,-1,-1):
v = UOp.variable("i", 0, 5).bind(i)
out = Tensor.ones(10, dtype=dtypes.int).contiguous().shrink(((0,v),)).sum()
self.assertEqual(out.item(), i)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main(verbosity=2) unittest.main(verbosity=2)

View File

@@ -107,14 +107,20 @@ base_rewrite = PatternMatcher([
# range # range
(UPat(Ops.RANGE, name="r"), lambda ctx,r: (UPat(Ops.RANGE, name="r"), lambda ctx,r:
f" br label %loop_entry_{range_str(r)}\nloop_entry_{range_str(r)}:\n" f" br label %loop_entry_{range_str(r)}\n"
f" br label %loop_body_{range_str(r)}\nloop_body_{range_str(r)}:\n" f"loop_entry_{range_str(r)}:\n"
f" {ctx[r]} = phi {ldt(r.dtype)} [ 0, %loop_entry_{range_str(r)} ], [ {ctx[r]}phi, %loop_latch_{range_str(r)} ]"), f" br label %loop_latch_{range_str(r)}\n"
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE, name="r")), name="x"), lambda ctx,x,r: f"loop_latch_{range_str(r)}:\n"
f" br label %loop_latch_{range_str(r)}\nloop_latch_{range_str(r)}:\n" f" {ctx[r]} = phi {ldt(r.dtype)} [ 0, %loop_entry_{range_str(r)} ], [ {ctx[r]}phi, %loop_footer_{range_str(r)} ]\n"
f" {ctx[r]}phi = add {ldt(r.dtype)} {ctx[r]}, 1\n" f" {ctx[r]}phi = add {ldt(r.dtype)} {ctx[r]}, 1\n"
f" {ctx[x]} = icmp ult {ldt(r.dtype)} {ctx[r]}phi, {ctx[r.src[0]]}\n" f" {ctx[r]}cmp = icmp ult {ldt(r.dtype)} {ctx[r]}, {ctx[r.src[0]]}\n"
f" br i1 {ctx[x]}, label %loop_body_{range_str(r)}, label %loop_exit_{range_str(r)}\nloop_exit_{range_str(r)}:"), f" br i1 {ctx[r]}cmp, label %loop_body_{range_str(r)}, label %loop_exit_{range_str(r)}\n"
f"loop_body_{range_str(r)}:"),
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE, name="r"))), lambda r:
f" br label %loop_footer_{range_str(r)}\n"
f"loop_footer_{range_str(r)}:\n"
f" br label %loop_latch_{range_str(r)}\n"
f"loop_exit_{range_str(r)}:"),
# if # if
(UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"), (UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),

View File

@@ -3,7 +3,7 @@ from tinygrad.dtype import AddrSpace, DType, PtrDType, dtypes
from tinygrad.helpers import DEBUG, OSX, unwrap from tinygrad.helpers import DEBUG, OSX, unwrap
from tinygrad.renderer import Renderer from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer from tinygrad.renderer.cstyle import CUDARenderer
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str
import tinygrad.runtime.autogen.mesa as mesa import tinygrad.runtime.autogen.mesa as mesa
import base64, ctypes, ctypes.util, struct, functools, inspect import base64, ctypes, ctypes.util, struct, functools, inspect
@@ -182,14 +182,17 @@ class NIRRenderer(Renderer):
self.r[u] = nimm(self.b, self.b.shader.contents.info.shared_size, dtypes.long) self.r[u] = nimm(self.b, self.b.shader.contents.info.shared_size, dtypes.long)
self.b.shader.contents.info.shared_size += u.dtype.nbytes() self.b.shader.contents.info.shared_size += u.dtype.nbytes()
elif u.op == Ops.RANGE: elif u.op == Ops.RANGE:
ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{u.arg[0]}".encode()).contents)) ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{range_str(u)}".encode()).contents))
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype) nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype)
mesa.nir_push_loop(self.b) mesa.nir_push_loop(self.b)
self.r[u] = nload(self.b, AddrSpace.REG, i, u.dtype) self.r[u] = nload(self.b, AddrSpace.REG, i, u.dtype)
nif(self.b, nalu(self.b, "ilt", self.r[u], self.r[u.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break))
elif u.op == Ops.END: elif u.op == Ops.END:
r = u.src[1] r = u.src[1]
nif(self.b, nalu(self.b, "ilt", x:=nalu(self.b, "iadd", self.r[r], nimm(self.b, 1, r.dtype)), self.r[r.src[0]]), next_i = nalu(self.b, "iadd", self.r[r], nimm(self.b, 1, r.dtype))
functools.partial(nstore, self.b, AddrSpace.REG, ranges.pop(), x, r.dtype), lambda: njump(self.b, mesa.nir_jump_break)) # TODO: this nif should be removable ... but TestMultiTensor.test_double_matmul_shard_W_0 segfaults with it gone
nif(self.b, nalu(self.b, "ilt", next_i, self.r[r.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break))
nstore(self.b, AddrSpace.REG, ranges.pop(), next_i, r.dtype),
mesa.nir_pop_loop(self.b, None) mesa.nir_pop_loop(self.b, None)
else: else:
if (d:=self.def_rewrite.rewrite(u, ctx=self)) is None: raise RuntimeError(f"failed to render {u.op} srcs {[x.dtype for x in u.src]}") if (d:=self.def_rewrite.rewrite(u, ctx=self)) is None: raise RuntimeError(f"failed to render {u.op} srcs {[x.dtype for x in u.src]}")

View File

@@ -119,8 +119,12 @@ string_rewrite = PatternMatcher([
if x.dtype.count > 1 else f"ld.{mem_type(buf)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"), if x.dtype.count > 1 else f"ld.{mem_type(buf)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
# simple # simple
(UPat(Ops.DEFINE_REG, src=()), lambda ctx: []), (UPat(Ops.DEFINE_REG, src=()), lambda ctx: []),
(UPat(Ops.RANGE, name="r"), lambda ctx, r: [f"mov.u32 {ctx.r[r]}, 0;", "LOOP_" + f"{ctx.r[r][1:]}:"]), (UPat(Ops.RANGE, name="r"), lambda ctx, r: [
f"mov.u32 {ctx.r[r]}, -1;",
f"bra END_{ctx.r[r][1:]};",
"LOOP_" + f"{ctx.r[r][1:]}:"]),
(UPat(Ops.END, name="x", src=(UPat(), UPat(Ops.RANGE, name="r"))), lambda ctx, x, r: [ (UPat(Ops.END, name="x", src=(UPat(), UPat(Ops.RANGE, name="r"))), lambda ctx, x, r: [
"END_" + f"{ctx.r[r][1:]}:",
ctx.code_for_op[Ops.ADD](ctx.r[r], ctx.r[r], "1", dtypes.int, ctx.types[dtypes.int]), ctx.code_for_op[Ops.ADD](ctx.r[r], ctx.r[r], "1", dtypes.int, ctx.types[dtypes.int]),
ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[r], ctx.r[r.src[0]], dtypes.int, ctx.types[dtypes.int]), ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[r], ctx.r[r.src[0]], dtypes.int, ctx.types[dtypes.int]),
f"@{ctx.r[x]} bra LOOP_{ctx.r[r][1:]};"]), f"@{ctx.r[x]} bra LOOP_{ctx.r[r][1:]};"]),

View File

@@ -52,41 +52,38 @@ def generic_wmma_helper(inp, warp_size, WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_
class PythonProgram: class PythonProgram:
def __init__(self, name:str, lib:bytes): def __init__(self, name:str, lib:bytes):
self.uops: list[tuple[Ops, DType|None, list[int], Any]] = pickle.loads(lib) self.uops: list[tuple[Ops, DType, list[int], Any]] = pickle.loads(lib)
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False): def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
st = time.perf_counter() st = time.perf_counter()
warp = list(itertools.product(*[range(x) for x in local_size[::-1]])) warp = list(itertools.product(*[range(x) for x in local_size[::-1]]))
warp_size = len(warp) warp_size = len(warp)
void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP, Ops.STORE}
loop_ends: dict[int, int] = {srcs[1]:i for i, (uop, _, srcs, _) in enumerate(self.uops) if uop == Ops.END}
for idxs in itertools.product(*[range(x) for x in global_size[::-1]]): for idxs in itertools.product(*[range(x) for x in global_size[::-1]]):
ul: dict[int, Any] = {} values: dict[int, Any] = {}
dl: dict[int, DType] = {}
pbufs: list[memoryview] = list(bufs) pbufs: list[memoryview] = list(bufs)
pvals: list[int] = list(vals) pvals: list[int] = list(vals)
i = 0 i = 0
loop_ends: dict[int, int] = {}
while i < len(self.uops): while i < len(self.uops):
uop, dtype, idp, arg = self.uops[i] uop, dtype, srcs, arg = self.uops[i]
void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP, Ops.STORE} src_values = [values[v] for v in srcs if self.uops[v][0] not in void_ops]
inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops] src_dtypes = [self.uops[v][1] for v in srcs if self.uops[v][0] not in void_ops]
dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops] if getenv("TRACE"): print(i, uop, dtype, arg, src_values, src_dtypes)
if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
if uop is Ops.END: if uop is Ops.END:
loop_ends[idp[1]] = i i = srcs[1]
i = idp[1]
continue continue
if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP): if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP):
# in the python emulator, the warp is always in sync # in the python emulator, the warp is always in sync
i += 1 i += 1
continue continue
assert dtype is not None, f"{uop} is missing a dtype" assert dtype is not None, f"{uop} is missing a dtype"
dl[i] = dtype
if uop is Ops.STORE: if uop is Ops.STORE:
for j,val in enumerate(inp[1] if dtp[1].count > 1 else [inp[1]]): for j,val in enumerate(src_values[1] if src_dtypes[1].count > 1 else [src_values[1]]):
for (m,o,g),v in zip(inp[0], val): for (m,o,g),v in zip(src_values[0], val):
if g: _store(m, o+j, v, dtp[1].scalar()) if g: _store(m, o+j, v, src_dtypes[1].scalar())
i += 1 i += 1
continue continue
if uop is Ops.AFTER: ul[i] = inp[0] if uop is Ops.AFTER: values[i] = src_values[0]
elif uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}: elif uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
assert isinstance(dtype, PtrDType), dtype assert isinstance(dtype, PtrDType), dtype
storage_fmt = storage_fmt_for_dtype(dtype.base.scalar()) storage_fmt = storage_fmt_for_dtype(dtype.base.scalar())
@@ -94,72 +91,73 @@ class PythonProgram:
if TYPE_CHECKING or sys.version_info < (3, 12): assert storage_fmt != "e" if TYPE_CHECKING or sys.version_info < (3, 12): assert storage_fmt != "e"
if uop is Ops.DEFINE_REG: if uop is Ops.DEFINE_REG:
# REGs are per thread # REGs are per thread
ul[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)] values[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)]
else: else:
buf = memoryview(bytearray(dtype.size*dtype.itemsize)) if uop is not Ops.DEFINE_GLOBAL else pbufs.pop(0) buf = memoryview(bytearray(dtype.size*dtype.itemsize)) if uop is not Ops.DEFINE_GLOBAL else pbufs.pop(0)
ul[i] = [buf.cast(storage_fmt)] * warp_size values[i] = [buf.cast(storage_fmt)] * warp_size
elif uop is Ops.DEFINE_VAR: elif uop is Ops.DEFINE_VAR:
ul[i] = [pvals.pop(0)] * warp_size values[i] = [pvals.pop(0)] * warp_size
elif uop is Ops.SPECIAL: elif uop is Ops.SPECIAL:
if arg[0] == 'g': ul[i] = [idxs[2-int(arg[-1])]] * warp_size if arg[0] == 'g': values[i] = [idxs[2-int(arg[-1])]] * warp_size
elif arg[0] == 'l': ul[i] = [x[2-int(arg[-1])] for x in warp] elif arg[0] == 'l': values[i] = [x[2-int(arg[-1])] for x in warp]
elif uop is Ops.CONST: ul[i] = [arg] * warp_size elif uop is Ops.CONST: values[i] = [arg] * warp_size
elif uop is Ops.INDEX: elif uop is Ops.INDEX:
ret:list = [] ret:list = []
if isinstance(dtp[0], ImageDType): if isinstance(src_dtypes[0], ImageDType):
for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]): for m,ox,oy in zip(src_values[0], src_values[1][0], src_values[1][1]):
if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append((m, None)) if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
else: ret.append((m, ox*4 + oy*dtp[0].shape[1]*4)) else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
else: else:
for m,o in zip(inp[0], inp[1]): ret.append((m,o)) for m,o in zip(src_values[0], src_values[1]): ret.append((m,o))
ul[i] = [(m,o,g) for (m,o),g in zip(ret, inp[2] if len(inp) == 3 else [True]*len(ret))] # set the gate last values[i] = [(m,o,g) for (m,o),g in zip(ret, src_values[2] if len(src_values) == 3 else [True]*len(ret))] # set the gate last
elif uop is Ops.CAST and isinstance(dtype, PtrDType): elif uop is Ops.CAST and isinstance(dtype, PtrDType):
ul[i] = inp[0] values[i] = src_values[0]
elif uop is Ops.RANGE: elif uop is Ops.RANGE:
if i not in ul: ul[i] = [0] * warp_size if i not in values: values[i] = [0] * warp_size
else: else:
for j in range(len(ul[i])): for j in range(len(values[i])):
ul[i][j] += 1 values[i][j] += 1
if ul[i][0] == inp[0][0]: if values[i][0] == src_values[0][0]:
del ul[i] del values[i]
i = loop_ends[i] + 1 i = loop_ends[i] + 1
continue continue
elif uop is Ops.VECTORIZE: ul[i] = inp elif uop is Ops.VECTORIZE: values[i] = src_values
elif uop is Ops.BITCAST: elif uop is Ops.BITCAST:
packed = struct.pack(str(warp_size) + storage_fmt_for_dtype(dtp[0].scalar()), *[to_storage_scalar(x, dtp[0].scalar()) for x in inp[0]]) packed = struct.pack(str(warp_size) + storage_fmt_for_dtype(src_dtypes[0].scalar()),
ul[i] = list(struct.unpack(str(warp_size) + storage_fmt_for_dtype(dtype.scalar()), packed)) *[to_storage_scalar(x, src_dtypes[0].scalar()) for x in src_values[0]])
ul[i] = [from_storage_scalar(x, dtype.scalar()) for x in ul[i]] values[i] = list(struct.unpack(str(warp_size) + storage_fmt_for_dtype(dtype.scalar()), packed))
values[i] = [from_storage_scalar(x, dtype.scalar()) for x in values[i]]
elif uop is Ops.CAST: elif uop is Ops.CAST:
ul[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in inp[0]] values[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in src_values[0]]
elif uop is Ops.LOAD: elif uop is Ops.LOAD:
if dtype.count > 1: if dtype.count > 1:
ul[i] = [load([inp[i][j] if i != 0 and dtp[i].count > 1 else inp[i] for i in range(len(inp))], j, dtype.scalar()) \ values[i] = [load([src_values[i][j] if i != 0 and src_dtypes[i].count > 1 else src_values[i] \
for j in range(dtype.count)] for i in range(len(src_values))], j, dtype.scalar()) for j in range(dtype.count)]
else: else:
ul[i] = load(inp, 0, dtype) values[i] = load(src_values, 0, dtype)
elif uop is Ops.GEP: ul[i] = inp[0][get_single_element(arg)] elif uop is Ops.GEP: values[i] = src_values[0][get_single_element(arg)]
elif uop is Ops.WMMA: elif uop is Ops.WMMA:
first_src_dtype = self.uops[idp[0]][1] first_src_dtype = self.uops[srcs[0]][1]
assert isinstance(first_src_dtype, DType) # mypy assert isinstance(first_src_dtype, DType) # mypy
dims, dtype_in, device, threads = arg[1], first_src_dtype.scalar(), arg[4], arg[5] dims, dtype_in, device, threads = arg[1], first_src_dtype.scalar(), arg[4], arg[5]
wmma_helper = functools.partial(generic_wmma_helper, inp, warp_size) wmma_helper = functools.partial(generic_wmma_helper, src_values, warp_size)
# TODO: refactor these to a shared TensorCoreLayout in kernel.py # TODO: refactor these to a shared TensorCoreLayout in kernel.py
if device == "METAL": if device == "METAL":
# A (2 elements on 32 threads): row major # A (2 elements on 32 threads): row major
def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16] def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
# (i, j), C, D (2 elements on 32 threads): row major same as A/B # (i, j), C, D (2 elements on 32 threads): row major same as A/B
def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4) def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map) values[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
elif device == "AMD" and threads == 64: elif device == "AMD" and threads == 64:
def a_elem(x, k, row, goff): return x[k%(dims[2]//4)][goff + (k//(dims[2]//4))*16 + row] def a_elem(x, k, row, goff): return x[k%(dims[2]//4)][goff + (k//(dims[2]//4))*16 + row]
def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
def c_map(lane, elem): return (lane%16, (lane//16)*4 + elem) def c_map(lane, elem): return (lane%16, (lane//16)*4 + elem)
ul[i] = wmma_helper(64, dims[2], len(inp[0]), len(inp[1]), len(inp[2]), a_elem, b_elem, c_map) values[i] = wmma_helper(64, dims[2], len(src_values[0]), len(src_values[1]), len(src_values[2]), a_elem, b_elem, c_map)
elif device == "AMD" and len(inp[0]) == 8: # RDNA4 elif device == "AMD" and len(src_values[0]) == 8: # RDNA4
def a_elem(x, k, row, goff): return x[k - [0, 4, 4, 8][k//4]][goff + row + [0, 16, 0, 16][k//4]] def a_elem(x, k, row, goff): return x[k - [0, 4, 4, 8][k//4]][goff + row + [0, 16, 0, 16][k//4]]
def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) def b_elem(x, col, k, goff): return a_elem(x, k, col, goff)
def c_map(lane, elem): return (lane%16, (lane//16)*8 + elem) def c_map(lane, elem): return (lane%16, (lane//16)*8 + elem)
ul[i] = wmma_helper(32, 16, 8, 8, 8, a_elem, b_elem, c_map) values[i] = wmma_helper(32, 16, 8, 8, 8, a_elem, b_elem, c_map)
elif device == "AMD": elif device == "AMD":
# A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15 # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
def a_elem(x, k, row, goff): def a_elem(x, k, row, goff):
@@ -168,7 +166,7 @@ class PythonProgram:
# B (16 elements on 32 threads): row major, lane 16-32 == lane 0-15 # B (16 elements on 32 threads): row major, lane 16-32 == lane 0-15
def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map) values[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
elif device == "CUDA": elif device == "CUDA":
# (col, row) given (lane, elem) for C & D (4 elements on 32 threads); shared by all tc shapes with M=16 N=8 # (col, row) given (lane, elem) for C & D (4 elements on 32 threads); shared by all tc shapes with M=16 N=8
def c_map(lane, elem): return (elem%2 + (lane%4)*2, lane//4 + (elem//2)*8) def c_map(lane, elem): return (elem%2 + (lane%4)*2, lane//4 + (elem//2)*8)
@@ -176,22 +174,22 @@ class PythonProgram:
if dims == (8,16,16): if dims == (8,16,16):
def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2 + (k//8)*4][goff + (k//2)%4 + (row%8)*4] def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2 + (k//8)*4][goff + (k//2)%4 + (row%8)*4]
def b_elem(x, col, k, goff): return x[k%2 + (k//8)*2][goff + (k//2)%4 + col*4] def b_elem(x, col, k, goff): return x[k%2 + (k//8)*2][goff + (k//2)%4 + col*4]
ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map) values[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
elif dims == (8,16,32): elif dims == (8,16,32):
def a_elem(x, k, row, goff): return x[k%4 + (row//8)*4 + (k//16)*8][goff + (k//4)%4 + (row%8)*4] def a_elem(x, k, row, goff): return x[k%4 + (row//8)*4 + (k//16)*8][goff + (k//4)%4 + (row%8)*4]
def b_elem(x, col, k, goff): return x[k%4 + (k//16)*4][goff + (k//4)%4 + col*4] def b_elem(x, col, k, goff): return x[k%4 + (k//16)*4][goff + (k//4)%4 + col*4]
ul[i] = wmma_helper(32, 32, 16, 8, 4, a_elem, b_elem, c_map) values[i] = wmma_helper(32, 32, 16, 8, 4, a_elem, b_elem, c_map)
elif dims == (8,16,8) and dtype_in == dtypes.half: elif dims == (8,16,8) and dtype_in == dtypes.half:
def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2][goff + k//2 + (row%8)*4] def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2][goff + k//2 + (row%8)*4]
def b_elem(x, col, k, goff): return x[k%2][goff + k//2 + col*4] def b_elem(x, col, k, goff): return x[k%2][goff + k//2 + col*4]
ul[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map) values[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
elif dims == (8,16,8) and dtype_in == dtypes.float: elif dims == (8,16,8) and dtype_in == dtypes.float:
def a_elem(x, k, row, goff): return x[(k//4)*2 + row//8][goff + k%4 + (row%8)*4] def a_elem(x, k, row, goff): return x[(k//4)*2 + row//8][goff + k%4 + (row%8)*4]
def b_elem(x, col, k, goff): return x[k//4][goff + k%4 + col*4] def b_elem(x, col, k, goff): return x[k//4][goff + k%4 + col*4]
ul[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map) values[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
else: raise NotImplementedError(f"unimplemented tensor core {arg}") else: raise NotImplementedError(f"unimplemented tensor core {arg}")
elif device == "INTEL": elif device == "INTEL":
@@ -201,17 +199,17 @@ class PythonProgram:
def b_elem(x, col, k, goff): return x[k][goff+col] def b_elem(x, col, k, goff): return x[k][goff+col]
# C, D (8 elements on 8 threads) # C, D (8 elements on 8 threads)
def c_map(lane, elem): return (lane, elem) def c_map(lane, elem): return (lane, elem)
ul[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map) values[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
elif device == "CPU": elif device == "CPU":
def elem(x, col, row, _): return x[col+row][0] # k is always 0 def elem(x, col, row, _): return x[col+row][0] # k is always 0
def c_map(lane, elem): return (elem%16, elem//16) def c_map(lane, elem): return (elem%16, elem//16)
ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map) values[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
else: raise NotImplementedError(f"unimplemented tensor core {arg}") else: raise NotImplementedError(f"unimplemented tensor core {arg}")
elif uop in GroupOp.ALU: elif uop in GroupOp.ALU:
assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {uop}" assert all_same([len(x) for x in src_values]), f"{[len(x) for x in src_values]} doesn't match on {uop}"
assert all_same([dtype] + dtp) or uop in {*GroupOp.Comparison, Ops.WHERE}, f"dtype mismatch on {uop}" assert all_same([dtype] + src_dtypes) or uop in {*GroupOp.Comparison, Ops.WHERE}, f"dtype mismatch on {uop}"
ul[i] = [exec_alu(uop, dtype, p) for p in zip(*inp)] values[i] = [exec_alu(uop, dtype, p) for p in zip(*src_values)]
assert i in ul, (uop, dtype, idp, arg) assert i in values, (uop, dtype, srcs, arg)
i += 1 i += 1
return time.perf_counter() - st return time.perf_counter() - st