mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
moved extras/jit.py -> tinygrad/jit.py (#599)
* moved extras/jit.py to tinygrad/jit.py * fixed indent * removed tinygrad.helpers.DEBUG from jit.py
This commit is contained in:
@@ -34,7 +34,7 @@ if __name__ == "__main__":
|
||||
model = EfficientNet(0)
|
||||
model.load_from_pretrained()
|
||||
|
||||
from extra.jit import TinyJit
|
||||
from tinygrad.jit import TinyJit
|
||||
@TinyJit
|
||||
def run(x): return model.forward(x).realize()
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ def create_onnx_model(keras_model):
|
||||
def compile_onnx_model(onnx_model):
|
||||
run_onnx = get_run_onnx(onnx_model)
|
||||
|
||||
from extra.jit import TinyJit
|
||||
from tinygrad.jit import TinyJit
|
||||
@TinyJit
|
||||
def run(x): return run_onnx({"x": x}, debug=False)['predictions'].realize()
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from PIL import Image
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import getenv
|
||||
from extra.utils import fetch
|
||||
from extra.jit import TinyJit
|
||||
from tinygrad.jit import TinyJit
|
||||
from models.efficientnet import EfficientNet
|
||||
np.set_printoptions(suppress=True)
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ class SpeedyResNet:
|
||||
# note, pytorch just uses https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html instead of log_softmax
|
||||
def __call__(self, x): return x.sequential(self.net).log_softmax()
|
||||
|
||||
from extra.jit import TinyJit
|
||||
from tinygrad.jit import TinyJit
|
||||
@TinyJit
|
||||
def train_step_jitted(model, optimizer, X, Y):
|
||||
out = model(X)
|
||||
|
||||
@@ -33,7 +33,7 @@ def get_random_input_tensors(input_shapes):
|
||||
np_inputs = {k:v.realize().numpy() for k,v in inputs.items()}
|
||||
return inputs, np_inputs
|
||||
|
||||
from extra.jit import TinyJit
|
||||
from tinygrad.jit import TinyJit
|
||||
|
||||
@TinyJit
|
||||
def model_exec(run_onnx, using_graph, **inputs):
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from extra.jit import TinyJit
|
||||
from tinygrad.jit import TinyJit
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "GPU", "JIT is only for GPU")
|
||||
class TestJit(unittest.TestCase):
|
||||
|
||||
@@ -11,7 +11,7 @@ from tinygrad.ops import GlobalCounters
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import Conv2d
|
||||
from tinygrad.helpers import colored, getenv, DEBUG
|
||||
from extra.jit import TinyJit
|
||||
from tinygrad.jit import TinyJit
|
||||
METAL = getenv("METAL")
|
||||
try:
|
||||
from tinygrad.runtime.opencl import CL
|
||||
|
||||
@@ -2,7 +2,6 @@ from typing import Callable, List, Tuple
|
||||
import itertools
|
||||
from tinygrad.lazy import Device
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.ops import GlobalCounters
|
||||
|
||||
class TinyJit:
|
||||
Reference in New Issue
Block a user