put that there

This commit is contained in:
George Hotz
2025-12-03 14:15:13 -08:00
parent e644d59f9f
commit 9cdda8913f

View File

@@ -143,7 +143,7 @@ pm_post_sched_cache = PatternMatcher([
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.LUNIQUE)), name="b"), replace_input_buffer_back),
])
schedule_cache: dict[bytes, tuple[UOp, dict[UOp, UOp]]] = {}
schedule_cache: dict[bytes, tuple[UOp, UOp]] = {}
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}", True)
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ScheduleItem], dict[str, int]]:
@@ -174,15 +174,16 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
big_sink = big_sink_cache.substitute(tensor_map, name="Apply Kernelize Map")
# save in schedule cache
schedule_cache[sched_cache_key] = (big_sink, tensor_map)
tensor_map_sink = UOp.sink(*flatten([(k,v) for k,v in tensor_map.items()]))
schedule_cache[sched_cache_key] = (big_sink, tensor_map_sink)
else:
# schedule cache hit
big_sink, tensor_map = sc_ret
del big_sink_cache
big_sink, tensor_map_sink = sc_ret
# replace all the LUNIQUEs with UNIQUEs
input_buffers_reverse = {v:k for k,v in input_buffers.items()}
big_sink = graph_rewrite(big_sink, pm_post_sched_cache, ctx=input_buffers_reverse, name="unrewrite for sched cache")
tensor_map_sink = UOp.sink(*flatten([(k,v) for k,v in tensor_map.items()]))
tm_src = graph_rewrite(tensor_map_sink, pm_post_sched_cache, ctx=input_buffers_reverse, name="unrewrite for tensor map").src
tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)}