Mirror of https://github.com/tinygrad/tinygrad.git
minor grouper + viz fixup [pr] (#10217)
* minor grouper + viz fixup [pr]
* gitignore mypy_cache
* reorder create_kernels
* replace with realized
* use tensor_map + viz before spec
* lint
* add that back
.gitignore (vendored)
@@ -61,3 +61,4 @@ site/
 profile_stats
 *.log
 target
+.mypy_cache
@@ -248,6 +248,20 @@ def create_kernel(ctx:KernelContext, x:UOp, b:UOp):

 DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER}

+insert_kernels = merge_views+PatternMatcher([
+  # always give assign/contiguous a kernel
+  (UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel),
+  (UPat(Ops.CONTIGUOUS, name="x"), lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype))),
+  # create a buffer for COPY on the new device
+  (UPat(Ops.COPY, src=(UPat(), UPat(Ops.DEVICE, name="d")), name="x"), lambda ctx,d,x: create_kernel(ctx, x, UOp.new_buffer(d.arg, x.size, x.dtype))),
+  # remove CONST/BIND/VIEW from SINK
+  (UPat(Ops.SINK, name="x"), lambda x: x.replace(src=new_src)
+    if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST, Ops.BIND}))) != x.src else None),
+  # otherwise check the context if we're realizing this UOp
+  (UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"),
+   lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype)) if x in ctx.realizes else None),
+])
+
 def append_to_kernel(ctx:KernelContext, x:UOp):
   new_srcs: list[UOp] = []
   metadata = dict.fromkeys(x.arg.metadata)
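For readers skimming the diff: every entry in insert_kernels pairs a UPat with a callback that either returns a replacement UOp or returns None to decline the rewrite. A minimal self-contained toy of that convention (made-up names, not tinygrad's PatternMatcher):

# toy illustration only: a rule table where each callback returns a replacement or None,
# mimicking the (UPat, lambda) pairs in insert_kernels above; all names here are made up
from typing import Callable, Optional

Rule = tuple[Callable[[str], bool], Callable[[str], Optional[str]]]

def apply_rules(rules: list[Rule], node: str) -> str:
  for matches, rewrite in rules:
    if matches(node) and (out := rewrite(node)) is not None:
      return out        # first rule that matches and returns non-None wins
  return node           # no rule fired: the node is left unchanged

rules: list[Rule] = [
  (lambda n: n == "CONTIGUOUS", lambda n: "KERNEL(CONTIGUOUS)"),  # always gets a kernel
  (lambda n: True,              lambda n: None),                  # matches, but declines
]

print(apply_rules(rules, "CONTIGUOUS"))  # -> KERNEL(CONTIGUOUS)
print(apply_rules(rules, "CONST"))       # -> CONST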
@@ -258,21 +272,8 @@ def append_to_kernel(ctx:KernelContext, x:UOp):
     if s.base.op not in {Ops.CONST, Ops.DEVICE} and (m:=ctx.metadata.get(s)): metadata[m] = None
   if (new_src:=tuple(dedup(new_srcs))) != x.src: return x.replace(src=new_src, arg=Kernel(x.arg.ast, tuple(metadata)))

-create_kernels = merge_views+PatternMatcher([
-  # always give assign/contiguous a kernel
-  (UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel),
-  (UPat(Ops.CONTIGUOUS, name="x"), lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype))),
-  # create a buffer for COPY on the new device
-  (UPat(Ops.COPY, src=(UPat(), UPat(Ops.DEVICE, name="d")), name="x"), lambda ctx,d,x: create_kernel(ctx, x, UOp.new_buffer(d.arg, x.size, x.dtype))),
-  # otherwise check the context if we're realizing this UOp
-  (UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"),
-   lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype)) if x in ctx.realizes else None),
-  # walk back the local graph until we reach a buffer/assign parent
-  (UPat(Ops.KERNEL, name="x"), append_to_kernel),
-  # remove CONST/BIND/VIEW from SINK
-  (UPat(Ops.SINK, name="x"), lambda x: x.replace(src=new_src)
-    if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST, Ops.BIND}))) != x.src else None),
-])
+# walk back the local graph until we reach a realized parent
+create_kernels = insert_kernels+PatternMatcher([(UPat(Ops.KERNEL, name="x"), append_to_kernel),])

 # **** swizzler

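The reorder splits the old create_kernels table in two: insert_kernels decides where kernels are created, and create_kernels is just insert_kernels plus one rule that keeps appending sources to an existing KERNEL. In the same toy spirit as above, the + composition can be sketched like this (hypothetical classes, not the real PatternMatcher):

# toy sketch of rule-table composition via +, mirroring
# create_kernels = insert_kernels+PatternMatcher([...]) above; everything here is hypothetical
class Matcher:
  def __init__(self, rules): self.rules = list(rules)
  def __add__(self, other): return Matcher(self.rules + other.rules)
  def rewrite(self, node):
    for matches, fn in self.rules:
      if matches(node) and (out := fn(node)) is not None: return out
    return None   # nothing matched: the caller keeps the original node

insert_kernels = Matcher([(lambda n: n == "CONTIGUOUS", lambda n: "KERNEL")])
create_kernels = insert_kernels + Matcher([(lambda n: n == "KERNEL", lambda n: "KERNEL+src")])

print(create_kernels.rewrite("CONTIGUOUS"))  # "KERNEL", rule inherited from insert_kernels
print(create_kernels.rewrite("KERNEL"))      # "KERNEL+src", the extra append_to_kernel-style rule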
@@ -396,7 +397,7 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
     if s.op is Ops.ASSIGN:
       for out in s.src[1].arg.ast.src: parents_rep[out] = s.buf_uop.view(unwrap(out.st))
     parents_rep[s] = s.buf_uop
-  ast = k.arg.ast.substitute(parents_rep)
+  ast = k.arg.ast.substitute(parents_rep, name="replace realized")
   # push views to edges
   ast = graph_rewrite(graph_rewrite(ast, view_left, name="Main View Left"), view_right, name="Main View Right")
   # replace buffer with define_global + add load/store last
@@ -489,10 +490,9 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
   if getenv("VIZ"): graph_rewrite(tensor_map[big_sink], PatternMatcher([]), name="View Tensor Graph")

   # group into kernels
-  sink = tensor_map[big_sink]
-  realize_map = group_realizes(sink)
-  tensor_map = graph_rewrite_map(sink, create_kernels, ctx=KernelContext(realize_map, {v:k.metadata for k,v in tensor_map.items()}), bottom_up=True,
-                                 input_map=tensor_map, name="create_kernels")
+  realize_map = group_realizes(tensor_map[big_sink])
+  tensor_map = graph_rewrite_map(tensor_map[big_sink], create_kernels, ctx=KernelContext(realize_map, {v:k.metadata for k,v in tensor_map.items()}),
+                                 bottom_up=True, input_map=tensor_map, name="create_kernels")

   # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
   kernel_assign: dict[UOp, UOp] = {}
@@ -510,15 +510,15 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
   # finally, create the AST for kernels
   tensor_map = graph_rewrite_map(tensor_map[big_sink], create_ast, bottom_up=True, input_map=tensor_map, name="create_ast")

-  # display the final graph
   sched_sink = tensor_map[big_sink]
-  if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
-  if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Memory Graph")

   # verify Kernels match the spec
   type_verify(list(sched_sink.toposort()), sched_spec)

+  # display the final graph
+  if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
+  if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Memory Graph")

   # map tensors to buffer/assign/const, optionally apply a VIEW on top
   becomes_map: dict[UOp, UOp] = {}
   for k,v in tensor_map.items():
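With VIZ unset, the graph_rewrite calls with an empty PatternMatcher above are effectively no-ops that just name a graph snapshot for the viewer, so this reshuffle only changes where the snapshots are taken. A minimal way to exercise the path, assuming a recent tinygrad checkout (how the recorded graphs are served on exit may vary by version):

# sketch: record the named graph snapshots ("View Tensor Graph", "create_kernels", ...) for viz;
# assumption: VIZ must be in the environment before tinygrad is imported so getenv("VIZ") sees it
import os
os.environ["VIZ"] = "1"

from tinygrad import Tensor

# any small realized computation runs through get_becomes_map/create_kernels above
(Tensor.ones(16).contiguous() + 1).realize()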
@@ -83,7 +83,7 @@ async function renderDag(graph, additions, recenter=false) {
 DTYPE_SIZE = {"bool": 1, "char": 1, "uchar": 1, "short": 2, "ushort": 2, "int": 4, "uint": 4,
               "long": 8, "ulong": 8, "half": 2, "bfloat": 2, "float": 4, "double": 8}
 function getBuffer(e) {
-  const [_, size, dtype, device, num] = e.label.split("\n");
+  const [_, size, dtype, num, device] = e.label.split("\n");
   return {nbytes:size*DTYPE_SIZE[dtype.split("dtypes.")[1]], dtype, device:device.split(" ")[1], num:parseInt(num.split(" ")[1])};
 }
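The JS change only swaps the expected order of the num and device lines in a node's label. As a sanity check, a rough Python mirror of the new parsing with a made-up label (the real label text comes from the viz server, so its exact layout is an assumption):

# hypothetical Python mirror of the reordered getBuffer destructuring above
DTYPE_SIZE = {"bool": 1, "char": 1, "uchar": 1, "short": 2, "ushort": 2, "int": 4, "uint": 4,
              "long": 8, "ulong": 8, "half": 2, "bfloat": 2, "float": 4, "double": 8}

def get_buffer(label: str) -> dict:
  _, size, dtype, num, device = label.split("\n")            # num now comes before device
  return {"nbytes": int(size)*DTYPE_SIZE[dtype.split("dtypes.")[1]], "dtype": dtype,
          "device": device.split(" ")[1], "num": int(num.split(" ")[1])}

print(get_buffer("BUFFER\n16\ndtypes.float\nnum 0\ndevice CPU"))   # made-up label for illustration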