This commit is contained in:
George Hotz
2025-12-19 13:08:08 -04:00
parent 9af2409fc3
commit 1f63af467d

View File

@@ -14,7 +14,7 @@ class TinyJit(Generic[ReturnType]):
def __call__(self, *args, **kwargs) -> ReturnType:
global schedule_capturing
# realize all inputs to the JIT
input_state_dict = get_state_dict((args, kwargs, None))
input_state_dict = get_state_dict((args, kwargs))
Tensor.realize(*input_state_dict.values())
# capture the schedules that are run
@@ -22,7 +22,8 @@ class TinyJit(Generic[ReturnType]):
schedule_capturing.append(self)
ret = self.fxn(*args, **kwargs)
# this gets all tensors referenced in the output
Tensor.realize(*get_state_dict(ret).values())
output_state_dict = get_state_dict(ret)
Tensor.realize(*output_state_dict.values())
schedule_capturing = []
print(f"JIT schedules:{len(self.schedule_caches)} inp:{len(input_state_dict)}")
@@ -31,6 +32,9 @@ class TinyJit(Generic[ReturnType]):
for k,v in input_state_dict.items(): print(k, v.uop.base)
for input_buffers, sched_cache_key in self.schedule_caches:
pre_schedule, combined_sink = schedule_cache[sched_cache_key]
for si in pre_schedule:
print(si.bufs)
for k,v in input_buffers.items():
print(k.pyrender())
print(v.pyrender())