mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
axis colors
This commit is contained in:
@@ -72,12 +72,12 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
|
||||
# ** expander (expand_rewrite) **
|
||||
ret.append(RewriteStep(sym+migrate_indexing, name="initial symbolic"))
|
||||
|
||||
# add gpu dims (late). this also handles UNROLL range
|
||||
ret.append(RewriteStep(pm_add_gpudims, lambda _: opts, name="add gpudims"))
|
||||
|
||||
# expand
|
||||
ret.append(RewriteStep(sym+expander, name="expander"))
|
||||
|
||||
# add gpu dims (late). this also handles UNROLL range
|
||||
ret.append(RewriteStep(pm_add_gpudims, lambda _: opts, name="add gpudims"))
|
||||
|
||||
# add locals
|
||||
ret.append(RewriteStep(pm_flatten_range+pm_add_buffers_local+rangeify_codegen, name="add local buffers"))
|
||||
|
||||
|
||||
@@ -181,7 +181,6 @@ def apply_tensor_cores(ctx:tuple[dict, Renderer], in0:UOp, in1:UOp, r_range:UOp,
|
||||
srcs = [s.substitute(dict(zip(old_range, new_range))).substitute(dict(zip(ne, tne))) for s in (in0, in1)]
|
||||
srcs = [x.substitute(dict(zip(tne, [ne[i] for i in p]))) for x,p in zip(srcs, tc.permutes_for_shape_str(tc.base_shape_str()))]
|
||||
|
||||
# get reduce/upcast axes for the tensor cores
|
||||
ned = dict(zip(tc.base_shape_str(), ne))
|
||||
tc_reduce_axes = tuple([ned[f"r{i}"].arg[0] for i in range(len(tc.get_reduce_axes()))])
|
||||
base_upcast_axes = tuple([(ned[s].arg[0], 2) for s in tc.base_upcast_axes()])
|
||||
@@ -193,9 +192,12 @@ def apply_tensor_cores(ctx:tuple[dict, Renderer], in0:UOp, in1:UOp, r_range:UOp,
|
||||
wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=(
|
||||
UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0]),
|
||||
UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1]),
|
||||
UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0))+tuple(red_ranges), arg=wmma_arg)
|
||||
UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg)
|
||||
tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2])
|
||||
return tc_uop.reduce(new_reduce_range, arg=Ops.ADD)
|
||||
ret = tc_uop.reduce(new_reduce_range, arg=Ops.ADD)
|
||||
# confirm the UNROLLs aren't actually used, these need to be broadcast MUL
|
||||
assert all(u not in red_ranges for u in ret.toposort()), "UNROLLs in TC"
|
||||
return ret
|
||||
|
||||
from tinygrad.codegen.opt.postrange import pm_flatten_range
|
||||
|
||||
|
||||
@@ -238,6 +238,7 @@ def no_vectorized_buf(buf:UOp):
|
||||
|
||||
def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
|
||||
cnt = cast.dtype.count
|
||||
if idx.dtype.count > 1: return None
|
||||
assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}"
|
||||
return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.int.vec(cnt), tuple(range(cnt))))
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from tinygrad.uop.ops import TrackedGraphRewrite, UOp, Ops, printable, GroupOp,
|
||||
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.codegen.opt.kernel import axis_colors
|
||||
|
||||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
|
||||
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_REG: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
|
||||
@@ -79,7 +80,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
||||
if u.op not in {Ops.VIEW, Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None:
|
||||
label += f"\n{shape_to_str(u.shape)}"
|
||||
elif len(rngs:=u.ranges):
|
||||
label += f"\n{str(sorted([x.arg[0] for x in rngs]))}"
|
||||
label += f"\n({','.join(sorted([colored(str(x.arg[0]), axis_colors[x.arg[1]]) for x in rngs]))})"
|
||||
except Exception:
|
||||
label += "\n<ISSUE GETTING LABEL>"
|
||||
if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"
|
||||
|
||||
Reference in New Issue
Block a user