mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 06:34:03 -05:00
work
This commit is contained in:
@@ -14,7 +14,7 @@ class TinyJit(Generic[ReturnType]):
|
||||
def __call__(self, *args, **kwargs) -> ReturnType:
|
||||
global schedule_capturing
|
||||
# realize all inputs to the JIT
|
||||
input_state_dict = get_state_dict((args, kwargs, None))
|
||||
input_state_dict = get_state_dict((args, kwargs))
|
||||
Tensor.realize(*input_state_dict.values())
|
||||
|
||||
# capture the schedules that are run
|
||||
@@ -22,7 +22,8 @@ class TinyJit(Generic[ReturnType]):
|
||||
schedule_capturing.append(self)
|
||||
ret = self.fxn(*args, **kwargs)
|
||||
# this gets all tensors referenced in the output
|
||||
Tensor.realize(*get_state_dict(ret).values())
|
||||
output_state_dict = get_state_dict(ret)
|
||||
Tensor.realize(*output_state_dict.values())
|
||||
schedule_capturing = []
|
||||
|
||||
print(f"JIT schedules:{len(self.schedule_caches)} inp:{len(input_state_dict)}")
|
||||
@@ -31,6 +32,9 @@ class TinyJit(Generic[ReturnType]):
|
||||
for k,v in input_state_dict.items(): print(k, v.uop.base)
|
||||
for input_buffers, sched_cache_key in self.schedule_caches:
|
||||
pre_schedule, combined_sink = schedule_cache[sched_cache_key]
|
||||
for si in pre_schedule:
|
||||
print(si.bufs)
|
||||
|
||||
for k,v in input_buffers.items():
|
||||
print(k.pyrender())
|
||||
print(v.pyrender())
|
||||
|
||||
Reference in New Issue
Block a user