moved extras/jit.py -> tinygrad/jit.py (#599)

* moved extras/jit.py to tinygrad/jit.py

* fixed indent

* removed tinygrad.helpers.DEBUG from jit.py
This commit is contained in:
voidz
2023-02-25 22:02:33 +05:30
committed by GitHub
parent 7348e9a6c6
commit 94bec40110
8 changed files with 7 additions and 8 deletions

View File

@@ -34,7 +34,7 @@ if __name__ == "__main__":
model = EfficientNet(0)
model.load_from_pretrained()
from extra.jit import TinyJit
from tinygrad.jit import TinyJit
@TinyJit
def run(x): return model.forward(x).realize()

View File

@@ -29,7 +29,7 @@ def create_onnx_model(keras_model):
def compile_onnx_model(onnx_model):
run_onnx = get_run_onnx(onnx_model)
from extra.jit import TinyJit
from tinygrad.jit import TinyJit
@TinyJit
def run(x): return run_onnx({"x": x}, debug=False)['predictions'].realize()

View File

@@ -12,7 +12,7 @@ from PIL import Image
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv
from extra.utils import fetch
from extra.jit import TinyJit
from tinygrad.jit import TinyJit
from models.efficientnet import EfficientNet
np.set_printoptions(suppress=True)

View File

@@ -49,7 +49,7 @@ class SpeedyResNet:
# note, pytorch just uses https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html instead of log_softmax
def __call__(self, x): return x.sequential(self.net).log_softmax()
from extra.jit import TinyJit
from tinygrad.jit import TinyJit
@TinyJit
def train_step_jitted(model, optimizer, X, Y):
out = model(X)

View File

@@ -33,7 +33,7 @@ def get_random_input_tensors(input_shapes):
np_inputs = {k:v.realize().numpy() for k,v in inputs.items()}
return inputs, np_inputs
from extra.jit import TinyJit
from tinygrad.jit import TinyJit
@TinyJit
def model_exec(run_onnx, using_graph, **inputs):

View File

@@ -2,7 +2,7 @@
import unittest
import numpy as np
from tinygrad.tensor import Tensor, Device
from extra.jit import TinyJit
from tinygrad.jit import TinyJit
@unittest.skipUnless(Device.DEFAULT == "GPU", "JIT is only for GPU")
class TestJit(unittest.TestCase):

View File

@@ -11,7 +11,7 @@ from tinygrad.ops import GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d
from tinygrad.helpers import colored, getenv, DEBUG
from extra.jit import TinyJit
from tinygrad.jit import TinyJit
METAL = getenv("METAL")
try:
from tinygrad.runtime.opencl import CL

View File

@@ -2,7 +2,6 @@ from typing import Callable, List, Tuple
import itertools
from tinygrad.lazy import Device
from tinygrad.tensor import Tensor
from tinygrad.helpers import DEBUG
from tinygrad.ops import GlobalCounters
class TinyJit: