diff --git a/.gitignore b/.gitignore index a1b09eff5d..fd1a734bea 100644 --- a/.gitignore +++ b/.gitignore @@ -61,3 +61,4 @@ site/ profile_stats *.log target +.mypy_cache diff --git a/tinygrad/engine/grouper.py b/tinygrad/engine/grouper.py index 2b63e68165..15554de584 100644 --- a/tinygrad/engine/grouper.py +++ b/tinygrad/engine/grouper.py @@ -248,6 +248,20 @@ def create_kernel(ctx:KernelContext, x:UOp, b:UOp): DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER} +insert_kernels = merge_views+PatternMatcher([ + # always give assign/contiguous a kernel + (UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel), + (UPat(Ops.CONTIGUOUS, name="x"), lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype))), + # create a buffer for COPY on the new device + (UPat(Ops.COPY, src=(UPat(), UPat(Ops.DEVICE, name="d")), name="x"), lambda ctx,d,x: create_kernel(ctx, x, UOp.new_buffer(d.arg, x.size, x.dtype))), + # remove CONST/BIND/VIEW from SINK + (UPat(Ops.SINK, name="x"), lambda x: x.replace(src=new_src) + if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST, Ops.BIND}))) != x.src else None), + # otherwise check the context if we're realizing this UOp + (UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"), + lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype)) if x in ctx.realizes else None), +]) + def append_to_kernel(ctx:KernelContext, x:UOp): new_srcs: list[UOp] = [] metadata = dict.fromkeys(x.arg.metadata) @@ -258,21 +272,8 @@ def append_to_kernel(ctx:KernelContext, x:UOp): if s.base.op not in {Ops.CONST, Ops.DEVICE} and (m:=ctx.metadata.get(s)): metadata[m] = None if (new_src:=tuple(dedup(new_srcs))) != x.src: return x.replace(src=new_src, arg=Kernel(x.arg.ast, tuple(metadata))) -create_kernels = merge_views+PatternMatcher([ - # always give assign/contiguous a kernel - (UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel), - (UPat(Ops.CONTIGUOUS, name="x"), lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype))), - # create a buffer for COPY on the new device - (UPat(Ops.COPY, src=(UPat(), UPat(Ops.DEVICE, name="d")), name="x"), lambda ctx,d,x: create_kernel(ctx, x, UOp.new_buffer(d.arg, x.size, x.dtype))), - # otherwise check the context if we're realizing this UOp - (UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"), - lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype)) if x in ctx.realizes else None), - # walk back the local graph until we reach a buffer/assign parent - (UPat(Ops.KERNEL, name="x"), append_to_kernel), - # remove CONST/BIND/VIEW from SINK - (UPat(Ops.SINK, name="x"), lambda x: x.replace(src=new_src) - if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST, Ops.BIND}))) != x.src else None), -]) +# walk back the local graph until we reach a realized parent +create_kernels = insert_kernels+PatternMatcher([(UPat(Ops.KERNEL, name="x"), append_to_kernel),]) # **** swizzler @@ -396,7 +397,7 @@ def fix_kernel_ast(k:UOp) -> UOp|None: if s.op is Ops.ASSIGN: for out in s.src[1].arg.ast.src: parents_rep[out] = s.buf_uop.view(unwrap(out.st)) parents_rep[s] = s.buf_uop - ast = k.arg.ast.substitute(parents_rep) + ast = k.arg.ast.substitute(parents_rep, name="replace realized") # push views to edges ast = graph_rewrite(graph_rewrite(ast, view_left, name="Main View Left"), view_right, name="Main View Right") # replace buffer with define_global + add load/store last @@ -489,10 +490,9 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]: if getenv("VIZ"): graph_rewrite(tensor_map[big_sink], PatternMatcher([]), name="View Tensor Graph") # group into kernels - sink = tensor_map[big_sink] - realize_map = group_realizes(sink) - tensor_map = graph_rewrite_map(sink, create_kernels, ctx=KernelContext(realize_map, {v:k.metadata for k,v in tensor_map.items()}), bottom_up=True, - input_map=tensor_map, name="create_kernels") + realize_map = group_realizes(tensor_map[big_sink]) + tensor_map = graph_rewrite_map(tensor_map[big_sink], create_kernels, ctx=KernelContext(realize_map, {v:k.metadata for k,v in tensor_map.items()}), + bottom_up=True, input_map=tensor_map, name="create_kernels") # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign kernel_assign: dict[UOp, UOp] = {} @@ -510,15 +510,15 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]: # finally, create the AST for kernels tensor_map = graph_rewrite_map(tensor_map[big_sink], create_ast, bottom_up=True, input_map=tensor_map, name="create_ast") + + # display the final graph sched_sink = tensor_map[big_sink] + if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph") + if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Memory Graph") # verify Kernels match the spec type_verify(list(sched_sink.toposort()), sched_spec) - # display the final graph - if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph") - if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Memory Graph") - # map tensors to buffer/assign/const, optionally apply a VIEW on top becomes_map: dict[UOp, UOp] = {} for k,v in tensor_map.items(): diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index 39aa6136f9..970bca83ac 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -83,7 +83,7 @@ async function renderDag(graph, additions, recenter=false) { DTYPE_SIZE = {"bool": 1, "char": 1, "uchar": 1, "short": 2, "ushort": 2, "int": 4, "uint": 4, "long": 8, "ulong": 8, "half": 2, "bfloat": 2, "float": 4, "double": 8} function getBuffer(e) { - const [_, size, dtype, device, num] = e.label.split("\n"); + const [_, size, dtype, num, device] = e.label.split("\n"); return {nbytes:size*DTYPE_SIZE[dtype.split("dtypes.")[1]], dtype, device:device.split(" ")[1], num:parseInt(num.split(" ")[1])}; }