diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index 8d7a660c07..fe1949ddfa 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -72,12 +72,12 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC # ** expander (expand_rewrite) ** ret.append(RewriteStep(sym+migrate_indexing, name="initial symbolic")) - # add gpu dims (late). this also handles UNROLL range - ret.append(RewriteStep(pm_add_gpudims, lambda _: opts, name="add gpudims")) - # expand ret.append(RewriteStep(sym+expander, name="expander")) + # add gpu dims (late). this also handles UNROLL range + ret.append(RewriteStep(pm_add_gpudims, lambda _: opts, name="add gpudims")) + # add locals ret.append(RewriteStep(pm_flatten_range+pm_add_buffers_local+rangeify_codegen, name="add local buffers")) diff --git a/tinygrad/codegen/gpudims.py b/tinygrad/codegen/gpudims.py index a59c4e2b4e..62b23fb7ea 100644 --- a/tinygrad/codegen/gpudims.py +++ b/tinygrad/codegen/gpudims.py @@ -181,7 +181,6 @@ def apply_tensor_cores(ctx:tuple[dict, Renderer], in0:UOp, in1:UOp, r_range:UOp, srcs = [s.substitute(dict(zip(old_range, new_range))).substitute(dict(zip(ne, tne))) for s in (in0, in1)] srcs = [x.substitute(dict(zip(tne, [ne[i] for i in p]))) for x,p in zip(srcs, tc.permutes_for_shape_str(tc.base_shape_str()))] - # get reduce/upcast axes for the tensor cores ned = dict(zip(tc.base_shape_str(), ne)) tc_reduce_axes = tuple([ned[f"r{i}"].arg[0] for i in range(len(tc.get_reduce_axes()))]) base_upcast_axes = tuple([(ned[s].arg[0], 2) for s in tc.base_upcast_axes()]) @@ -193,9 +192,12 @@ def apply_tensor_cores(ctx:tuple[dict, Renderer], in0:UOp, in1:UOp, r_range:UOp, wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=( UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0]), UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1]), - UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0))+tuple(red_ranges), arg=wmma_arg) + UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg) tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2]) - return tc_uop.reduce(new_reduce_range, arg=Ops.ADD) + ret = tc_uop.reduce(new_reduce_range, arg=Ops.ADD) + # confirm the UNROLLs aren't actually used, these need to be broadcast MUL + assert all(u not in red_ranges for u in ret.toposort()), "UNROLLs in TC" + return ret from tinygrad.codegen.opt.postrange import pm_flatten_range diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index b1b150beb0..9b8e2a4d16 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -238,6 +238,7 @@ def no_vectorized_buf(buf:UOp): def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp): cnt = cast.dtype.count + if idx.dtype.count > 1: return None assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}" return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.int.vec(cnt), tuple(range(cnt)))) diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index b0b1f5616c..196f89c8bb 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -11,6 +11,7 @@ from tinygrad.uop.ops import TrackedGraphRewrite, UOp, Ops, printable, GroupOp, from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device from tinygrad.renderer import ProgramSpec from tinygrad.dtype import dtypes +from tinygrad.codegen.opt.kernel import axis_colors uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B", Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_REG: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B", @@ -79,7 +80,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]: if u.op not in {Ops.VIEW, Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None: label += f"\n{shape_to_str(u.shape)}" elif len(rngs:=u.ranges): - label += f"\n{str(sorted([x.arg[0] for x in rngs]))}" + label += f"\n({','.join(sorted([colored(str(x.arg[0]), axis_colors[x.arg[1]]) for x in rngs]))})" except Exception: label += "\n" if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"