mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 22:38:16 -05:00
37 lines
1.5 KiB
Python
37 lines
1.5 KiB
Python
from typing import Callable, List, Tuple
|
|
import itertools
|
|
from tinygrad.lazy import Device
|
|
from tinygrad.tensor import Tensor
|
|
from tinygrad.ops import DEBUG, GlobalCounters
|
|
|
|
class TinyJit:
  """Capture-and-replay wrapper around a function that takes Tensors.

  Behaviour depends on how many calls have completed (``self.cnt``):
    * call 0: run ``fxn`` normally and remember its return value.
    * call 1: run ``fxn`` again while ``GlobalCounters.cache`` is a list.
      Whatever ``(program, args)`` pairs end up in that list become
      ``self.jit_cache``. Record which of those args hold an input buffer.
    * call 2+: do not run ``fxn``. Point the recorded args at this call's
      input buffers, re-invoke every cached program, and return the object
      ``fxn`` returned during capture.
  JIT is only applied when ``Device.DEFAULT == "GPU"``. On any other device
  every call goes straight to ``fxn`` and the call counter is not advanced.

  NOTE(review): this relies on something (presumably the kernel-launch path
  in tinygrad.ops) appending ``(program, args)`` pairs to
  ``GlobalCounters.cache`` whenever it is not None. Confirm there.
  """

  def __init__(self, fxn):
    self.fxn = fxn
    self.cnt = 0  # completed calls on the GPU path; selects run / capture / replay
    self.jit_cache : List[Tuple[Callable, List]] = []  # (program, args) pairs recorded during the capture call
    self.ret = None  # value returned by fxn on the last call that actually ran it
    self.input_replace = {}  # captured program arg -> input key (positional index or keyword name)

  def __call__(self, *args, **kwargs):
    """Run, capture, or replay ``self.fxn`` depending on the call count.

    Every ``Tensor`` among ``args``/``kwargs`` is realized first. Its buffer
    is keyed by positional index (int) or keyword name (str), so later calls
    must pass their inputs in the same positions/keywords.

    Returns:
      The result of ``fxn``. On replay calls this is the same object that was
      returned during the capture call (presumably refreshed in place by the
      replayed programs).

    Raises:
      AssertionError: if no Tensor inputs are given, if the capture call
        recorded no programs, or if some input buffer is not used by any
        captured program.
    """
    if Device.DEFAULT != "GPU": return self.fxn(*args, **kwargs)  # only jit on the GPU
    input_tensors = {k: v.realize().lazydata.realized._buf
                     for k, v in itertools.chain(enumerate(args), kwargs.items())
                     if isinstance(v, Tensor)}
    assert len(input_tensors) != 0, "no inputs to JIT"
    if self.cnt >= 2:
      # replay: swap this call's input buffers into the captured argument slots, then rerun the programs
      for a, idx in self.input_replace.items(): a._buf = input_tensors[idx]
      for prg, prg_args in self.jit_cache: prg(*prg_args)
    elif self.cnt == 1:
      # capture: a non-None GlobalCounters.cache turns on recording for the duration of this call
      GlobalCounters.cache = []
      try:
        self.ret = self.fxn(*args, **kwargs)
        self.jit_cache = GlobalCounters.cache
      finally:
        # always switch recording off again, even if fxn raised, so unrelated later work is not captured
        GlobalCounters.cache = None
      assert len(self.jit_cache) != 0, "didn't JIT anything!"

      # get the inputs for replacement: map each captured arg holding an input buffer to that input's key.
      # Built locally and committed only after validation so a failed capture leaves no stale entries.
      input_replace = {}
      for _, prg_args in self.jit_cache:  # pylint: disable=E1133
        for a in prg_args:
          matches = [k for k, buf in input_tensors.items() if buf == a._buf]
          if matches: input_replace[a] = matches[0]  # first matching input wins, as before
      assert set(input_replace.values()) == set(input_tensors.keys()), "some input tensors not found"
      self.input_replace = input_replace
    elif self.cnt == 0:
      self.ret = self.fxn(*args, **kwargs)
    self.cnt += 1
    return self.ret
|