axis colors

This commit is contained in:
George Hotz
2025-08-26 12:56:10 -07:00
parent 03fb0c9ad0
commit 4836d6bc60
4 changed files with 11 additions and 7 deletions

View File

@@ -72,12 +72,12 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
# ** expander (expand_rewrite) **
ret.append(RewriteStep(sym+migrate_indexing, name="initial symbolic"))
# add gpu dims (late). this also handles UNROLL range
ret.append(RewriteStep(pm_add_gpudims, lambda _: opts, name="add gpudims"))
# expand
ret.append(RewriteStep(sym+expander, name="expander"))
# add gpu dims (late). this also handles UNROLL range
ret.append(RewriteStep(pm_add_gpudims, lambda _: opts, name="add gpudims"))
# add locals
ret.append(RewriteStep(pm_flatten_range+pm_add_buffers_local+rangeify_codegen, name="add local buffers"))

View File

@@ -181,7 +181,6 @@ def apply_tensor_cores(ctx:tuple[dict, Renderer], in0:UOp, in1:UOp, r_range:UOp,
srcs = [s.substitute(dict(zip(old_range, new_range))).substitute(dict(zip(ne, tne))) for s in (in0, in1)]
srcs = [x.substitute(dict(zip(tne, [ne[i] for i in p]))) for x,p in zip(srcs, tc.permutes_for_shape_str(tc.base_shape_str()))]
# get reduce/upcast axes for the tensor cores
ned = dict(zip(tc.base_shape_str(), ne))
tc_reduce_axes = tuple([ned[f"r{i}"].arg[0] for i in range(len(tc.get_reduce_axes()))])
base_upcast_axes = tuple([(ned[s].arg[0], 2) for s in tc.base_upcast_axes()])
@@ -193,9 +192,12 @@ def apply_tensor_cores(ctx:tuple[dict, Renderer], in0:UOp, in1:UOp, r_range:UOp,
wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=(
UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0]),
UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1]),
UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0))+tuple(red_ranges), arg=wmma_arg)
UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg)
tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2])
return tc_uop.reduce(new_reduce_range, arg=Ops.ADD)
ret = tc_uop.reduce(new_reduce_range, arg=Ops.ADD)
# confirm the UNROLLs aren't actually used, these need to be broadcast MUL
assert all(u not in red_ranges for u in ret.toposort()), "UNROLLs in TC"
return ret
from tinygrad.codegen.opt.postrange import pm_flatten_range

View File

@@ -238,6 +238,7 @@ def no_vectorized_buf(buf:UOp):
def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
  """Build a per-lane index into `buf` from a scalar `idx`.

  With `lanes = cast.dtype.count`, both `buf` and `idx` are broadcast to
  `lanes` and lane `k` of the resulting index is `idx*lanes + k`.

  Returns None when `idx` is already vectorized (`idx.dtype.count > 1`).
  NOTE(review): presumably None tells the calling rewriter that no rewrite
  applies -- confirm against the pattern matcher that invokes this.
  """
  lanes = cast.dtype.count
  idx_width = idx.dtype.count
  # a vectorized idx is not handled here
  if idx_width > 1:
    return None
  assert idx_width == 1, f"idx dtype must be 1 {idx.dtype}"
  wide_buf = buf.broadcast(lanes)
  # scale the scalar index by the lane count, then add offsets 0..lanes-1
  scaled_idx = idx.broadcast(lanes)*lanes
  lane_offsets = UOp.const(dtypes.int.vec(lanes), tuple(range(lanes)))
  return wide_buf.index(scaled_idx+lane_offsets)

View File

@@ -11,6 +11,7 @@ from tinygrad.uop.ops import TrackedGraphRewrite, UOp, Ops, printable, GroupOp,
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device
from tinygrad.renderer import ProgramSpec
from tinygrad.dtype import dtypes
from tinygrad.codegen.opt.kernel import axis_colors
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_REG: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
@@ -79,7 +80,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
if u.op not in {Ops.VIEW, Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None:
label += f"\n{shape_to_str(u.shape)}"
elif len(rngs:=u.ranges):
label += f"\n{str(sorted([x.arg[0] for x in rngs]))}"
label += f"\n({','.join(sorted([colored(str(x.arg[0]), axis_colors[x.arg[1]]) for x in rngs]))})"
except Exception:
label += "\n<ISSUE GETTING LABEL>"
if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"