mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00
remap UOp metadata in graph_rewrite_map [pr] (#10234)
* remap metadata in graph_rewrite_map [pr]
* fix
* merge loops
* UOp.metadata returns Metadata|None
* shorter
This commit is contained in:
@@ -239,11 +239,10 @@ class Kernel:
|
||||
@dataclass(frozen=True)
|
||||
class KernelContext:
|
||||
realizes: dict[UOp, None]
|
||||
metadata: dict[UOp, Metadata|None]
|
||||
|
||||
def create_kernel(ctx:KernelContext, x:UOp, b:UOp|None=None):
|
||||
if b is None: b = UOp.new_buffer(x.device, x.size, x.dtype)
|
||||
kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), (m,) if (m:=ctx.metadata.get(x)) else ()))
|
||||
kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), (m,) if (m:=x.metadata) else ()))
|
||||
buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
|
||||
return UOp(Ops.ASSIGN, x.dtype, (buffer, kernel)).reshape(x.shape)
|
||||
|
||||
@@ -269,7 +268,7 @@ def append_to_kernel(ctx:KernelContext, x:UOp):
|
||||
if s.op in DONT_PLACE_IN_KERNEL or s in ctx.realizes: new_srcs.append(s)
|
||||
else:
|
||||
new_srcs.extend(s.src)
|
||||
if s.base.op not in {Ops.CONST, Ops.DEVICE} and (m:=ctx.metadata.get(s)): metadata[m] = None
|
||||
if s.base.op not in {Ops.CONST, Ops.DEVICE} and (m:=s.metadata): metadata[m] = None
|
||||
if (new_src:=tuple(dedup(new_srcs))) != x.src: return x.replace(src=new_src, arg=Kernel(x.arg.ast, tuple(metadata)))
|
||||
|
||||
# walk back the local graph until we reach a realized parent
|
||||
@@ -491,7 +490,7 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
|
||||
|
||||
# group into kernels
|
||||
realize_map = group_realizes(tensor_map[big_sink])
|
||||
tensor_map = graph_rewrite_map(tensor_map[big_sink], create_kernels, ctx=KernelContext(realize_map, {v:k.metadata for k,v in tensor_map.items()}),
|
||||
tensor_map = graph_rewrite_map(tensor_map[big_sink], create_kernels, ctx=KernelContext(realize_map),
|
||||
bottom_up=True, input_map=tensor_map, name="create_kernels")
|
||||
|
||||
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
|
||||
|
||||
@@ -494,7 +494,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return UOp(Ops.COPY, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), arg)
|
||||
def clone(self) -> UOp: return self.copy_to_device(self.device)
|
||||
@property
|
||||
def metadata(self) -> tuple[Metadata, ...]|Metadata|None: return self.arg.metadata if self.op is Ops.KERNEL else all_metadata.get(self, None)
|
||||
def metadata(self) -> Metadata|None: return all_metadata.get(self, None)
|
||||
|
||||
# *** uop movement ops ***
|
||||
|
||||
@@ -991,6 +991,7 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=N
|
||||
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, input_map:dict[UOp, UOp]|None=None) -> dict[UOp, UOp]:
|
||||
rewrite_ctx = RewriteContext(pm, ctx)
|
||||
new_map = {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort())[::-1]}
|
||||
all_metadata.update((v, k.metadata) for k,v in new_map.items() if k.metadata is not None)
|
||||
if input_map is not None:
|
||||
for k,v in input_map.items(): new_map[k] = new_map.get(v,v)
|
||||
return new_map
|
||||
|
||||
Reference in New Issue
Block a user