From 1f63af467d4d16fd7efe31f2cf6dadd4a5b28fd0 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Fri, 19 Dec 2025 13:08:08 -0400 Subject: [PATCH] work --- tinygrad/engine/jit2.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tinygrad/engine/jit2.py b/tinygrad/engine/jit2.py index 3665dedfe9..d37d283c14 100644 --- a/tinygrad/engine/jit2.py +++ b/tinygrad/engine/jit2.py @@ -14,7 +14,7 @@ class TinyJit(Generic[ReturnType]): def __call__(self, *args, **kwargs) -> ReturnType: global schedule_capturing # realize all inputs to the JIT - input_state_dict = get_state_dict((args, kwargs, None)) + input_state_dict = get_state_dict((args, kwargs)) Tensor.realize(*input_state_dict.values()) # capture the schedules that are run @@ -22,7 +22,8 @@ class TinyJit(Generic[ReturnType]): schedule_capturing.append(self) ret = self.fxn(*args, **kwargs) # this gets all tensors referenced in the output - Tensor.realize(*get_state_dict(ret).values()) + output_state_dict = get_state_dict(ret) + Tensor.realize(*output_state_dict.values()) schedule_capturing = [] print(f"JIT schedules:{len(self.schedule_caches)} inp:{len(input_state_dict)}") @@ -31,6 +32,9 @@ class TinyJit(Generic[ReturnType]): for k,v in input_state_dict.items(): print(k, v.uop.base) for input_buffers, sched_cache_key in self.schedule_caches: pre_schedule, combined_sink = schedule_cache[sched_cache_key] + for si in pre_schedule: + print(si.bufs) + for k,v in input_buffers.items(): print(k.pyrender()) print(v.pyrender())