minor grouper + viz fixup [pr] (#10217)

* minor grouper + viz fixup [pr]

* gitignore mypy_cache

* reorder create_kernels

* replace with realized

* use tensor_map + viz before spec

* lint

* add that back
commit 40560e77c2 (parent 0411b09763)
Author: qazal
Date: 2025-05-08 21:39:44 +03:00
Committed by: GitHub
3 changed files with 26 additions and 25 deletions

.gitignore

@@ -61,3 +61,4 @@ site/
 profile_stats
 *.log
 target
+.mypy_cache

@@ -248,6 +248,20 @@ def create_kernel(ctx:KernelContext, x:UOp, b:UOp):
 DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER}
+insert_kernels = merge_views+PatternMatcher([
+  # always give assign/contiguous a kernel
+  (UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel),
+  (UPat(Ops.CONTIGUOUS, name="x"), lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype))),
+  # create a buffer for COPY on the new device
+  (UPat(Ops.COPY, src=(UPat(), UPat(Ops.DEVICE, name="d")), name="x"), lambda ctx,d,x: create_kernel(ctx, x, UOp.new_buffer(d.arg, x.size, x.dtype))),
+  # remove CONST/BIND/VIEW from SINK
+  (UPat(Ops.SINK, name="x"), lambda x: x.replace(src=new_src)
+    if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST, Ops.BIND}))) != x.src else None),
+  # otherwise check the context if we're realizing this UOp
+  (UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"),
+   lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype)) if x in ctx.realizes else None),
+])
+
 def append_to_kernel(ctx:KernelContext, x:UOp):
   new_srcs: list[UOp] = []
   metadata = dict.fromkeys(x.arg.metadata)
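
For readers who don't know the idiom: a PatternMatcher is an ordered list of (pattern, rewrite) rules, each rewrite returning a replacement UOp or None to decline, and `+` concatenates rule lists. A minimal toy sketch of that contract (ToyMatcher is invented for illustration, not tinygrad's class):

# Toy sketch of the PatternMatcher idiom above, NOT tinygrad's implementation.
class ToyMatcher:
  def __init__(self, rules): self.rules = list(rules)
  # `+` concatenates rule lists, which is how insert_kernels gets reused below
  def __add__(self, other): return ToyMatcher(self.rules + other.rules)
  def rewrite(self, node, ctx=None):
    for pred, fxn in self.rules:
      if pred(node) and (ret := fxn(ctx, node)) is not None: return ret
    return None  # no rule fired

base = ToyMatcher([(lambda n: n == "CONTIGUOUS", lambda ctx, n: "KERNEL")])
extended = base + ToyMatcher([(lambda n: n == "COPY", lambda ctx, n: "KERNEL")])
assert extended.rewrite("COPY") == "KERNEL" and base.rewrite("COPY") is None
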
@@ -258,21 +272,8 @@ def append_to_kernel(ctx:KernelContext, x:UOp):
     if s.base.op not in {Ops.CONST, Ops.DEVICE} and (m:=ctx.metadata.get(s)): metadata[m] = None
   if (new_src:=tuple(dedup(new_srcs))) != x.src: return x.replace(src=new_src, arg=Kernel(x.arg.ast, tuple(metadata)))
-create_kernels = merge_views+PatternMatcher([
-  # always give assign/contiguous a kernel
-  (UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel),
-  (UPat(Ops.CONTIGUOUS, name="x"), lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype))),
-  # create a buffer for COPY on the new device
-  (UPat(Ops.COPY, src=(UPat(), UPat(Ops.DEVICE, name="d")), name="x"), lambda ctx,d,x: create_kernel(ctx, x, UOp.new_buffer(d.arg, x.size, x.dtype))),
-  # otherwise check the context if we're realizing this UOp
-  (UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"),
-   lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype)) if x in ctx.realizes else None),
-  # walk back the local graph until we reach a buffer/assign parent
-  (UPat(Ops.KERNEL, name="x"), append_to_kernel),
-  # remove CONST/BIND/VIEW from SINK
-  (UPat(Ops.SINK, name="x"), lambda x: x.replace(src=new_src)
-    if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST, Ops.BIND}))) != x.src else None),
-])
+# walk back the local graph until we reach a realized parent
+create_kernels = insert_kernels+PatternMatcher([(UPat(Ops.KERNEL, name="x"), append_to_kernel),])
 
 # **** swizzler
@@ -396,7 +397,7 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
     if s.op is Ops.ASSIGN:
       for out in s.src[1].arg.ast.src: parents_rep[out] = s.buf_uop.view(unwrap(out.st))
       parents_rep[s] = s.buf_uop
-  ast = k.arg.ast.substitute(parents_rep)
+  ast = k.arg.ast.substitute(parents_rep, name="replace realized")
   # push views to edges
   ast = graph_rewrite(graph_rewrite(ast, view_left, name="Main View Left"), view_right, name="Main View Right")
   # replace buffer with define_global + add load/store last
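
The only change in this hunk is giving the substitute pass a name, so "replace realized" shows up as its own labeled step in VIZ. Conceptually, substitute swaps every realized parent for its buffer across the AST; a toy recursive sketch (illustrative names, no caching, unlike the real thing):

from collections import namedtuple
Node = namedtuple("Node", ["op", "src"])

def substitute(node: Node, repl: dict, name=None):
  # `name` only labels the pass for a tracer; it doesn't change the result
  if node in repl: return repl[node]
  new_src = tuple(substitute(s, repl, name) for s in node.src)
  return node if new_src == node.src else Node(node.op, new_src)

leaf, buf = Node("ASSIGN", ()), Node("BUFFER", ())
ast = Node("SINK", (Node("ADD", (leaf, leaf)),))
assert substitute(ast, {leaf: buf}, name="replace realized") == Node("SINK", (Node("ADD", (buf, buf)),))
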
@@ -489,10 +490,9 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
if getenv("VIZ"): graph_rewrite(tensor_map[big_sink], PatternMatcher([]), name="View Tensor Graph")
# group into kernels
sink = tensor_map[big_sink]
realize_map = group_realizes(sink)
tensor_map = graph_rewrite_map(sink, create_kernels, ctx=KernelContext(realize_map, {v:k.metadata for k,v in tensor_map.items()}), bottom_up=True,
input_map=tensor_map, name="create_kernels")
realize_map = group_realizes(tensor_map[big_sink])
tensor_map = graph_rewrite_map(tensor_map[big_sink], create_kernels, ctx=KernelContext(realize_map, {v:k.metadata for k,v in tensor_map.items()}),
bottom_up=True, input_map=tensor_map, name="create_kernels")
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
kernel_assign: dict[UOp, UOp] = {}
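
This hunk drops the cached `sink` local and re-indexes `tensor_map[big_sink]` instead, so every pass starts from the latest rewritten sink. A toy dict-based sketch of why that matters (rewrite_pass stands in for graph_rewrite_map; it is not the real API):

def rewrite_pass(root: str, fxn, input_map: dict) -> dict:
  # rewrite the graph hanging off `root`, returning an updated tensor->uop map
  return {k: fxn(v) if v == root else v for k, v in input_map.items()}

big_sink = "sink"
tensor_map = {big_sink: big_sink}
tensor_map = rewrite_pass(tensor_map[big_sink], lambda v: v + "+kernels", tensor_map)
# re-index through the map: the next pass must see "sink+kernels", not a stale "sink"
tensor_map = rewrite_pass(tensor_map[big_sink], lambda v: v + "+ast", tensor_map)
assert tensor_map[big_sink] == "sink+kernels+ast"
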
@@ -510,15 +510,15 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
   # finally, create the AST for kernels
   tensor_map = graph_rewrite_map(tensor_map[big_sink], create_ast, bottom_up=True, input_map=tensor_map, name="create_ast")
+  # display the final graph
   sched_sink = tensor_map[big_sink]
+  if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
+  if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Memory Graph")
   # verify Kernels match the spec
   type_verify(list(sched_sink.toposort()), sched_spec)
-  # display the final graph
-  if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
-  if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Memory Graph")
   # map tensors to buffer/assign/const, optionally apply a VIEW on top
   becomes_map: dict[UOp, UOp] = {}
   for k,v in tensor_map.items():

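Per the "use tensor_map + viz before spec" bullet, the two VIZ renders now run before type_verify, so the kernel and memory graphs stay inspectable even when the spec check raises. A toy sketch of that ordering (type_verify_toy and the spec predicate are invented for illustration):

import os

def type_verify_toy(nodes, spec):
  # walk nodes in topological order and assert each satisfies the spec
  for n in nodes: assert spec(n), f"{n} does not match the spec"

nodes = ["KERNEL", "BUFFER", "NOT_IN_SPEC"]
if os.getenv("VIZ"): print("render kernel graph:", nodes)  # viz first...
try: type_verify_toy(nodes, lambda n: n != "NOT_IN_SPEC")  # ...then verify
except AssertionError as e: print("spec check failed:", e)
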
@@ -83,7 +83,7 @@ async function renderDag(graph, additions, recenter=false) {
 DTYPE_SIZE = {"bool": 1, "char": 1, "uchar": 1, "short": 2, "ushort": 2, "int": 4, "uint": 4,
               "long": 8, "ulong": 8, "half": 2, "bfloat": 2, "float": 4, "double": 8}
 function getBuffer(e) {
-  const [_, size, dtype, device, num] = e.label.split("\n");
+  const [_, size, dtype, num, device] = e.label.split("\n");
   return {nbytes:size*DTYPE_SIZE[dtype.split("dtypes.")[1]], dtype, device:device.split(" ")[1], num:parseInt(num.split(" ")[1])};
 }
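
The JS fix swaps the destructuring order so it matches the line order of the node label (presumably name, size, dtype, num, device after this commit; the layout is inferred from the code, not confirmed). The same parse in Python, purely for illustration, with a made-up label:

DTYPE_SIZE = {"bool": 1, "char": 1, "uchar": 1, "short": 2, "ushort": 2, "int": 4, "uint": 4,
              "long": 8, "ulong": 8, "half": 2, "bfloat": 2, "float": 4, "double": 8}

def get_buffer(label: str) -> dict:
  _, size, dtype, num, device = label.split("\n")
  return {"nbytes": int(size) * DTYPE_SIZE[dtype.split("dtypes.")[1]], "dtype": dtype,
          "device": device.split(" ")[1], "num": int(num.split(" ")[1])}

# hypothetical label in the assumed layout:
assert get_buffer("BUFFER\n16\ndtypes.float\nnum 0\ndevice METAL")["nbytes"] == 64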