Refactor getenv into helpers (#508)

* Refactor getenv into helpers

* Remove unused os

* Fix default value

* Fix more defaults for CI

* Fix bracket

* Revert changes to openpilot/compile.py

* Use getenv from helpers when possible
Jacky Lee
2023-01-31 15:09:09 -08:00
committed by GitHub
parent d91b6711ea
commit 799b3f185a
29 changed files with 102 additions and 101 deletions
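
The getenv helper these call sites migrate to is not itself shown in this diff. Judging from the new call sites (int defaults replacing int(os.getenv(...)) wrappers, while PRINT_AST keeps a string default), it coerces the environment value to the type of the default; a minimal sketch under that assumption:

import os

def getenv(key, default=0):
  # Parse the variable through the default's type: getenv("BS", 8)
  # yields an int, getenv("PRINT_AST", "0") stays a str, and an
  # unset variable falls back to the default unchanged.
  return type(default)(os.getenv(key, default))

Under that sketch, int(os.getenv("BS", "8")) and getenv("BS", 8) read the same variable and produce the same int, which is the transformation repeated in every file below.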

View File

@@ -1,8 +1,8 @@
#!/usr/bin/env python3
-import os
import sys
from hexdump import hexdump
from macholib import MachO
+from tinygrad.helpers import getenv
def get_macho(fn):
# mod to make the header okay
# MH_CIGAM_64 is good
@@ -124,7 +124,7 @@ for i in range(0, len(f2), 0x300):
c1, c2 = f1[i:i+0x300], f2[i:i+0x300]
dbg1 = ane.debug(c1, 16)
dbg2 = ane.debug(c2, 16)
-if os.getenv("PRINTALL"):
+if getenv("PRINTALL"):
for k in dbg2:
if k in aneregs:
rr = aneregs[k] if k in aneregs else (-1,-1,-1)

View File

@@ -14,7 +14,6 @@
# 189 Mb embedded RAM, aka 9M 19-bit elements
# 2560 MLP blocks, 2 fp24 MULACC each
-import os
import functools
import numpy as np
from collections import defaultdict
@@ -183,7 +182,7 @@ binops = {BinaryOps.ADD: riski_add,
reduceops = {ReduceOps.SUM: riski_reduce_sum,
ReduceOps.MAX: riski_reduce_max}
-SLOW_MATMUL = os.getenv("SLOW_MATMUL", False)
+SLOW_MATMUL = getenv("SLOW_MATMUL", False)
@count
def riski_matmul(slow=SLOW_MATMUL):
#print("LLL:\n",regfile[Reg.MATMUL_INPUT],"\n",regfile[Reg.MATMUL_WEIGHTS])
@@ -205,7 +204,7 @@ def riski_mov(tout, tin):
def riski_zero(tout):
regfile[tout][:, :] = 0
-load_log = open("/tmp/risk_load_log", "w") if os.getenv("LOAD_LOG") else None
+load_log = open("/tmp/risk_load_log", "w") if getenv("LOAD_LOG") else None
@count
def riski_load(target, address, stride_y=SZ, stride_x=1, len_y=SZ, len_x=SZ, zero=True, skip_first=False):

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python3
-import os
import time
from tqdm import trange
from models.efficientnet import EfficientNet
@@ -7,18 +6,19 @@ import tinygrad.nn.optim as optim
from tinygrad.tensor import Tensor
from tinygrad.llops.ops_gpu import CL
from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import getenv
import gc
def tensors_allocated():
return sum([isinstance(x, Tensor) for x in gc.get_objects()])
-NUM = int(os.getenv("NUM", 2))
-BS = int(os.getenv("BS", 8))
-CNT = int(os.getenv("CNT", 10))
-BACKWARD = int(os.getenv("BACKWARD", 0))
-TRAINING = int(os.getenv("TRAINING", 1))
-ADAM = int(os.getenv("ADAM", 0))
-CLCACHE = int(os.getenv("CLCACHE", "0"))
+NUM = getenv("NUM", 2)
+BS = getenv("BS", 8)
+CNT = getenv("CNT", 10)
+BACKWARD = getenv("BACKWARD", 0)
+TRAINING = getenv("TRAINING", 1)
+ADAM = getenv("ADAM", 0)
+CLCACHE = getenv("CLCACHE", 0)
if __name__ == "__main__":
print(f"NUM:{NUM} BS:{BS} CNT:{CNT}")

View File

@@ -9,6 +9,7 @@ import time
import numpy as np
np.set_printoptions(suppress=True)
from tinygrad.tensor import Tensor
+from tinygrad.helpers import getenv
from extra.utils import fetch, get_parameters
from models.efficientnet import EfficientNet
@@ -48,7 +49,7 @@ def infer(model, img):
if __name__ == "__main__":
# instantiate my net
-model = EfficientNet(int(os.getenv("NUM", "0")))
+model = EfficientNet(getenv("NUM", 0))
model.load_from_pretrained()
# category labels

View File

@@ -7,12 +7,13 @@ sys.path.append(os.getcwd())
sys.path.append(os.path.join(os.getcwd(), 'test'))
from tinygrad.tensor import Tensor
+from tinygrad.helpers import getenv
from extra.utils import get_parameters
import tinygrad.nn.optim as optim
from datasets import fetch_mnist
from torchvision.utils import make_grid, save_image
import torch
-GPU = os.getenv("GPU") is not None
+GPU = getenv("GPU")
class LinearGen:
def __init__(self):
lv = 128
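
Note the boolean flags: the old GPU = os.getenv("GPU") is not None only tested whether the variable was set, while getenv("GPU") parses its value. The two agree for unset or GPU=1; a hypothetical GPU=0 is where they differ:

import os

os.environ["GPU"] = "0"                  # hypothetical value, for illustration
old_flag = os.getenv("GPU") is not None  # True: being set at all counted
new_flag = int(os.getenv("GPU", 0))      # 0: what getenv("GPU") evaluates to
assert old_flag and not new_flag         # GPU=0 now reads as "off"

The same reasoning applies to the QUICK, DEBUG, IMAGENET, TINY and TRANSFER flags changed below.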

View File

@@ -12,10 +12,11 @@ from extra.utils import get_parameters
from datasets import fetch_mnist
from extra.training import train, evaluate, sparse_categorical_crossentropy
import tinygrad.nn.optim as optim
+from tinygrad.helpers import getenv
from extra.augment import augment_img
-GPU = os.getenv("GPU", None) is not None
-QUICK = os.getenv("QUICK", None) is not None
-DEBUG = os.getenv("DEBUG", None) is not None
+GPU = getenv("GPU")
+QUICK = getenv("QUICK")
+DEBUG = getenv("DEBUG")
class SqueezeExciteBlock2D:
def __init__(self, filters):

View File

@@ -8,6 +8,7 @@ from extra.utils import get_parameters
from tqdm import trange
from tinygrad.nn import BatchNorm2D
import tinygrad.nn.optim as optim
+from tinygrad.helpers import getenv
from datasets import fetch_cifar
class TinyConvNet:
@@ -29,24 +30,24 @@ class TinyConvNet:
return x.dot(self.l1)
if __name__ == "__main__":
-IMAGENET = os.getenv("IMAGENET") is not None
+IMAGENET = getenv("IMAGENET")
classes = 1000 if IMAGENET else 10
-TINY = os.getenv("TINY") is not None
-TRANSFER = os.getenv("TRANSFER") is not None
+TINY = getenv("TINY")
+TRANSFER = getenv("TRANSFER")
if TINY:
model = TinyConvNet(classes)
elif TRANSFER:
-model = EfficientNet(int(os.getenv("NUM", "0")), classes, has_se=True)
+model = EfficientNet(getenv("NUM", 0), classes, has_se=True)
model.load_from_pretrained()
else:
-model = EfficientNet(int(os.getenv("NUM", "0")), classes, has_se=False)
+model = EfficientNet(getenv("NUM", 0), classes, has_se=False)
parameters = get_parameters(model)
print("parameter count", len(parameters))
optimizer = optim.Adam(parameters, lr=0.001)
-BS, steps = int(os.getenv("BS", "64" if TINY else "16")), int(os.getenv("STEPS", "2048"))
+BS, steps = getenv("BS", 64 if TINY else 16), getenv("STEPS", 2048)
print("training with batch size %d for %d steps" % (BS, steps))
if IMAGENET:

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python3
-import os
import numpy as np
import random
from PIL import Image
@@ -9,9 +8,9 @@ from extra.utils import get_parameters
from extra.training import train, evaluate
from models.resnet import ResNet
from tinygrad.nn.optim import Adam
+from tinygrad.helpers import getenv
from datasets import fetch_mnist
-from tinygrad.nn.optim import Adam
class ComposeTransforms:
def __init__(self, trans):
@@ -28,9 +27,8 @@ if __name__ == "__main__":
X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
classes = 10
-TRANSFER = os.getenv('TRANSFER') is not None
-NUM = int(os.getenv('NUM', '18'))
-model = ResNet(NUM, num_classes=classes)
+TRANSFER = getenv('TRANSFER')
+model = ResNet(getenv('NUM', 18), num_classes=classes)
if TRANSFER:
model.load_from_pretrained()

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python3
-import os
import numpy as np
import random

View File

@@ -11,11 +11,11 @@ with tf.io.gfile.GFile(fn, "rb") as f:
from tinygrad.tensor import Tensor
+from tinygrad.helpers import getenv
from models.vit import ViT
-import os
Tensor.training = False
-if int(os.getenv("LARGE", "0")) == 1:
+if getenv("LARGE", 0) == 1:
m = ViT(embed_dim=768, num_heads=12)
else:
# tiny

View File

@@ -1,8 +1,6 @@
# https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg
# running
-import os
-GPU = os.getenv("GPU", None) is not None
import sys
import io
import time
@@ -12,6 +10,8 @@ from tinygrad.tensor import Tensor
from extra.utils import fetch, get_parameters
from examples.yolo.yolo_nn import Upsample, EmptyLayer, DetectionLayer, LeakyReLU, MaxPool2d
from tinygrad.nn import BatchNorm2D, Conv2d
+from tinygrad.helpers import getenv
+GPU = getenv("GPU")
import cv2
from PIL import Image

View File

@@ -1,5 +1,5 @@
#!/usr/bin/env python
-import os, random, traceback
+import random, traceback
import time
import itertools
from enum import Enum
@@ -9,6 +9,7 @@ from tinygrad.shape import ShapeTracker, View, ZeroView
from tinygrad.llops.ops_gpu import GPUBuffer, CLASTKernel, CL
from tinygrad.runtime.opencl import OSX_TIMING_RATIO
from tinygrad.ops import DEBUG
+from tinygrad.helpers import getenv
from extra.lib_test_ast import test_ast
import pickle, dbm
@@ -163,7 +164,7 @@ def test_correctness(ast):
test_ast(k)
if __name__ == "__main__":
-if int(os.getenv("OP", "0")) == 1:
+if getenv("OP", 0) == 1:
buf0 = GPUBuffer(shape=ShapeTracker(shape=(1, 64, 128, 8, 4, 3, 3, 3, 4), views=[View((1, 130, 258, 1, 12), (393216, 3072, 12, 12, 1), -3084), ZeroView((1, 128, 256, 1, 12), ((0, 1), (-1, 129), (-1, 257), (0, 1), (0, 12))), View((1, 64, 128, 8, 4, 3, 3, 3, 4), (0, 6192, 24, 0, 0, 3096, 12, 4, 1), 0)]), hostbuf=GPUBuffer(shape=(128, 768, 4), force_create=True))
buf1 = GPUBuffer(shape=ShapeTracker(shape=(1, 64, 128, 8, 4, 3, 3, 3, 4), views=[View((1, 64, 128, 8, 4, 3, 3, 3, 4), (0, 0, 0, 432, 4, 144, 16, 48, 1), 0)]), hostbuf=GPUBuffer(shape=(8, 108, 4), force_create=True))
op0 = LazyOp(BinaryOps.MUL, (buf0,buf1,), None)
@@ -179,7 +180,7 @@ if __name__ == "__main__":
op7 = LazyOp(BinaryOps.MUL, (buf3,op6,), None)
op8 = LazyOp(BinaryOps.SUB, (op3,op7,), None)
ast = LazyOp(MovementOps.RESHAPE, (op8,), (64, 1024, 4))
-elif int(os.getenv("OP", "0")) == 2:
+elif getenv("OP", 0) == 2:
buf0 = GPUBuffer(shape=ShapeTracker(shape=(1, 64, 128, 8, 4, 1, 1, 3, 3), views=[View((1, 66, 130, 32, 1), (262144, 4096, 32, 1, 1), -4128), ZeroView((1, 64, 128, 32, 1), ((0, 1), (-1, 65), (-1, 129), (0, 32), (0, 1))), View((1, 64, 128, 8, 4, 1, 1, 3, 3), (266240, 4160, 32, 4, 1, 12480, 12480, 4160, 32), 0)]), hostbuf=GPUBuffer(shape=(64, 1024, 4), force_create=True))
buf1 = GPUBuffer(shape=ShapeTracker(shape=(1, 64, 128, 8, 4, 1, 1, 3, 3), views=[View((1, 64, 128, 8, 4, 1, 1, 3, 3), (0, 0, 0, 36, 1, 0, 0, 12, 4), 0)]), hostbuf=GPUBuffer(shape=(8, 9, 4), force_create=True))
op0 = LazyOp(BinaryOps.MUL, (buf0,buf1,), None)
@@ -195,7 +196,7 @@ if __name__ == "__main__":
op7 = LazyOp(BinaryOps.MUL, (buf3,op6,), None)
op8 = LazyOp(BinaryOps.SUB, (op3,op7,), None)
ast = LazyOp(MovementOps.RESHAPE, (op8,), (64, 1024, 4))
-elif int(os.getenv("OP", "0")) == 3:
+elif getenv("OP", 0) == 3:
buf0 = GPUBuffer(shape=ShapeTracker(shape=(1, 64, 128, 4, 4, 1, 1, 8, 4), views=[View((1, 64, 128, 4, 4, 1, 1, 8, 4), (0, 4096, 32, 0, 0, 0, 0, 4, 1), 0)]), hostbuf=GPUBuffer(shape=(64, 1024, 4), force_create=True))
buf1 = GPUBuffer(shape=ShapeTracker(shape=(1, 64, 128, 4, 4, 1, 1, 8, 4), views=[View((1, 64, 128, 4, 4, 1, 1, 8, 4), (0, 0, 0, 128, 4, 0, 0, 16, 1), 0)]), hostbuf=GPUBuffer(shape=(4, 32, 4), force_create=True))
op0 = LazyOp(BinaryOps.MUL, (buf0,buf1,), None)
@@ -203,38 +204,38 @@ if __name__ == "__main__":
buf2 = GPUBuffer(shape=ShapeTracker(shape=(1, 64, 128, 4, 4, 1, 1, 1, 1), views=[View((1, 64, 128, 4, 4, 1, 1, 1, 1), (0, 0, 0, 4, 1, 1, 1, 1, 1), 0)]), hostbuf=GPUBuffer(shape=(16,), force_create=True))
op2 = LazyOp(BinaryOps.ADD, (op1,buf2,), None)
ast = LazyOp(MovementOps.RESHAPE, (op2,), (64, 512, 4))
-elif int(os.getenv("REDUCE", "0")):
+elif getenv("REDUCE", 0):
buf0 = GPUBuffer(shape=ShapeTracker(shape=(32, 8, 112, 112), views=[View((32, 8, 112, 112), (12544, 401408, 112, 1), 0)]), hostbuf=GPUBuffer(shape=(8, 32, 112, 112), force_create=True))
op0 = LazyOp(ReduceOps.SUM, (buf0,), (32, 1, 1, 1))
buf1 = GPUBuffer(shape=ShapeTracker(shape=(32, 1, 1, 1), views=[View((32, 1, 1, 1), (0, 0, 0, 0), 0)]), hostbuf=GPUBuffer(shape=(1,), backing=np.array([9.964923e-06], dtype=np.float32)))
op1 = LazyOp(BinaryOps.MUL, (op0,buf1,), None)
ast = LazyOp(MovementOps.RESHAPE, (op1,), (1, 32, 1, 1))
-elif int(os.getenv("CONVW", "0")):
+elif getenv("CONVW", 0):
buf0 = GPUBuffer(shape=ShapeTracker(shape=(64, 1, 128, 3, 3, 512, 32, 32), views=[View((64, 512, 34, 34), (1024, 65536, 32, 1), -33), ZeroView((64, 512, 32, 32), ((0, 64), (0, 512), (-1, 33), (-1, 33))), View((64, 1, 128, 3, 3, 512, 32, 32), (591872, 591872, 0, 34, 1, 1156, 34, 1), 0)]), hostbuf=GPUBuffer(shape=(512, 64, 32, 32), force_create=True))
buf1 = GPUBuffer(shape=ShapeTracker(shape=(64, 1, 128, 3, 3, 512, 32, 32), views=[View((64, 1, 128, 3, 3, 512, 32, 32), (0, 0, 1024, 0, 0, 131072, 32, 1), 0)]), hostbuf=GPUBuffer(shape=(512, 128, 32, 32), force_create=True))
op0 = LazyOp(BinaryOps.MUL, (buf0,buf1,), None)
op1 = LazyOp(ReduceOps.SUM, (op0,), (64, 1, 128, 3, 3, 1, 1, 1))
ast = LazyOp(MovementOps.RESHAPE, (op1,), (64, 128, 3, 3))
-elif int(os.getenv("BC", "0")):
+elif getenv("BC", 0):
# big conv
buf0 = GPUBuffer(shape=ShapeTracker(shape=(8, 1, 32, 112, 112, 3, 3, 3), views=[View((8, 3, 225, 225), (150528, 50176, 224, 1), 0), ZeroView((8, 3, 224, 224), ((0, 8), (0, 3), (0, 225), (0, 225))), View((8, 1, 32, 112, 112, 3, 3, 3), (151875, 151875, 0, 450, 2, 50625, 225, 1), 0)]), hostbuf=GPUBuffer(shape=(8, 3, 224, 224), force_create=True))
buf1 = GPUBuffer(shape=ShapeTracker(shape=(8, 1, 32, 112, 112, 3, 3, 3), views=[View((8, 1, 32, 112, 112, 3, 3, 3), (0, 0, 27, 0, 0, 9, 3, 1), 0)]), hostbuf=GPUBuffer(shape=(32, 3, 3, 3), force_create=True))
op0 = LazyOp(BinaryOps.MUL, (buf0,buf1,), None)
op1 = LazyOp(ReduceOps.SUM, (op0,), (8, 1, 32, 112, 112, 1, 1, 1))
ast = LazyOp(MovementOps.RESHAPE, (op1,), (8, 32, 112, 112))
-elif int(os.getenv("GEMM", "0")):
+elif getenv("GEMM", 0):
buf0 = GPUBuffer(shape=ShapeTracker(shape=(1, 1, 512, 512, 1, 1, 1, 512), views=[View((1, 512, 512, 1), (0, 1, 512, 0), 0), View((1, 1, 512, 512, 1, 1, 1, 512), (0, 0, 0, 1, 0, 0, 0, 512), 0)]), hostbuf=GPUBuffer(shape=(512, 512), force_create=True))
buf1 = GPUBuffer(shape=ShapeTracker(shape=(1, 1, 512, 512, 1, 1, 1, 512), views=[View((1, 1, 512, 512, 1, 1, 1, 512), (0, 0, 1, 0, 0, 0, 0, 512), 0)]), hostbuf=GPUBuffer(shape=(512, 512), force_create=True))
op0 = LazyOp(BinaryOps.MUL, (buf0,buf1,), None)
op1 = LazyOp(ReduceOps.SUM, (op0,), (1, 1, 512, 512, 1, 1, 1, 1))
ast = LazyOp(MovementOps.RESHAPE, (op1,), (512, 512))
-elif int(os.getenv("FASTCONV", "0")):
+elif getenv("FASTCONV", 0):
buf0 = GPUBuffer(shape=ShapeTracker(shape=(32, 1, 32, 32, 32, 64, 3, 3), views=[View((32, 1, 32, 32, 32, 64, 3, 3), (73984, 73984, 0, 34, 1, 1156, 34, 1), 0)]), hostbuf=GPUBuffer(shape=(32, 64, 34, 34), force_create=True))
buf1 = GPUBuffer(shape=ShapeTracker(shape=(32, 1, 32, 32, 32, 64, 3, 3), views=[View((32, 1, 32, 32, 32, 64, 3, 3), (0, 0, 576, 0, 0, 9, 3, 1), 0)]), hostbuf=GPUBuffer(shape=(32, 64, 3, 3), force_create=True))
op0 = LazyOp(BinaryOps.MUL, (buf0,buf1,), None)
op1 = LazyOp(ReduceOps.SUM, (op0,), (32, 1, 32, 32, 32, 1, 1, 1))
ast = LazyOp(MovementOps.RESHAPE, (op1,), (32, 32, 32, 32))
-elif int(os.getenv("BROKEN", "0")):
+elif getenv("BROKEN", 0):
buf0 = GPUBuffer(shape=ShapeTracker(shape=(64, 1, 1, 1), views=[View((64, 1, 1, 1), (1, 0, 0, 0), 0)]), hostbuf=GPUBuffer(shape=(64,), force_create=True))
buf1 = GPUBuffer(shape=ShapeTracker(shape=(64, 5, 32, 32), views=[View((64, 5, 32, 32), (5120, 1024, 32, 1), 0)]), hostbuf=GPUBuffer(shape=(64, 5, 32, 32), force_create=True))
op0 = LazyOp(ReduceOps.SUM, (buf1,), (64, 1, 1, 1))

View File

@@ -1,11 +1,11 @@
-import os
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod
from tinygrad.ops import DEBUG
+from tinygrad.helpers import getenv
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
-ONNXLIMIT = int(os.getenv("ONNXLIMIT", "-1"))
+ONNXLIMIT = getenv("ONNXLIMIT", -1)
def get_run_onnx(onnx_model):
def shape_to_tuple(s): return tuple(x.dim_value for x in s.dim)

View File

@@ -1,18 +1,17 @@
# this can be constructed from a cl_cache or loaded from a thneed file
-import os
import time
import struct
import json
import traceback
import numpy as np
from tinygrad.llops.ops_gpu import CL, CLProgram
-from tinygrad.helpers import prod
+from tinygrad.helpers import prod, getenv
from collections import defaultdict
import pyopencl as cl
from tinygrad.runtime.opencl import OSX_TIMING_RATIO
-DEBUGCL = int(os.getenv("DEBUGCL", 0))
-FLOAT16 = int(os.getenv("FLOAT16", 0))
+DEBUGCL = getenv("DEBUGCL", 0)
+FLOAT16 = getenv("FLOAT16", 0)
class Thneed:
def __init__(self, cl_cache=[], inputs={}):
@@ -284,7 +283,7 @@ class Thneed:
for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)):
runtime = (e.profile.end - e.profile.start) * OSX_TIMING_RATIO
print(f"{i:3d} time {total_runtime/1e6:5.2f} ms running {prg.name:20s} with {str(args[0]):15s} {str(args[1]):15s} count {len(args)-2:2d} runtime {runtime/1e3:7.2f} us {(prg.op_estimate)/runtime:9.2f} GFLOPS {prg.options} -> {args[2].shape if hasattr(args[2], 'shape') else args[2].size}")
-if (DEBUGCL >= 2 and int(os.getenv("PRINT_KERNEL", "-1")) == i) or DEBUGCL >= 3:
+if (DEBUGCL >= 2 and getenv("PRINT_KERNEL", -1) == i) or DEBUGCL >= 3:
print(prg.prg)
total_runtime += runtime
print(f"total runtime: {total_runtime/1e6:.2f} ms wall time: {et*1000.0:.2f} ms")

View File

@@ -1,8 +1,8 @@
-import os
import numpy as np
from tqdm import trange
from extra.utils import get_parameters
from tinygrad.tensor import Tensor, Device
+from tinygrad.helpers import getenv
def sparse_categorical_crossentropy(out, Y):
num_classes = out.shape[-1]
@@ -18,7 +18,7 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categoric
transform=lambda x: x, target_transform=lambda x: x, noloss=False):
Tensor.training = True
losses, accuracies = [], []
-for i in (t := trange(steps, disable=os.getenv('CI') is not None)):
+for i in (t := trange(steps, disable=getenv('CI', False))):
samp = np.random.randint(0, X_train.shape[0], size=(BS))
x = Tensor(transform(X_train[samp]), requires_grad=False)
y = target_transform(Y_train[samp])
@@ -48,7 +48,7 @@ def evaluate(model, X_test, Y_test, num_classes=None, BS=128, return_predict=Fal
Tensor.training = False
def numpy_eval(Y_test, num_classes):
Y_test_preds_out = np.zeros(list(Y_test.shape)+[num_classes])
-for i in trange((len(Y_test)-1)//BS+1, disable=os.getenv('CI') is not None):
+for i in trange((len(Y_test)-1)//BS+1, disable=getenv('CI', False)):
x = Tensor(transform(X_test[i*BS:(i+1)*BS]))
out = model.forward(x) if hasattr(model, 'forward') else model(x)
Y_test_preds_out[i*BS:(i+1)*BS] = out.cpu().data
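
The CI switch keeps its old meaning even though the default is a bool: under the getenv sketch above, bool() of any non-empty string is truthy, matching the old is-not-None test for the values CI is realistically set to. A quick check of that claim:

import os

os.environ.pop("CI", None)                    # CI unset
assert bool(os.getenv("CI", False)) is False  # progress bars stay enabled

os.environ["CI"] = "true"                     # typical CI value
assert bool(os.getenv("CI", False)) is True   # bars disabled, as before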

View File

@@ -2,7 +2,7 @@ from tinygrad.tensor import Tensor
import pickle
from tqdm import tqdm
import numpy as np
-from tinygrad.helpers import prod
+from tinygrad.helpers import prod, getenv
def fetch(url):
if url.startswith("/"):
@@ -11,7 +11,7 @@ def fetch(url):
return dat
import requests, os, hashlib, tempfile
fp = os.path.join(tempfile.gettempdir(), hashlib.md5(url.encode('utf-8')).hexdigest())
-if os.path.isfile(fp) and os.stat(fp).st_size > 0 and os.getenv("NOCACHE", None) is None:
+if os.path.isfile(fp) and os.stat(fp).st_size > 0 and not getenv("NOCACHE"):
with open(fp, "rb") as f:
dat = f.read()
else:

View File

@@ -7,8 +7,9 @@ if os.getenv("OPT", None) is None:
if os.getenv("GPU", None) is None:
os.environ['GPU'] = '1'
-ALLOWED_KERNEL_COUNT = int(os.getenv("ALLOWED_KERNEL_COUNT", 0))
-DEBUGCL = int(os.getenv("DEBUGCL", 0))
+from tinygrad.helpers import getenv
+ALLOWED_KERNEL_COUNT = getenv("ALLOWED_KERNEL_COUNT", 0)
+DEBUGCL = getenv("DEBUGCL", 0)
import onnx
import numpy as np
@@ -34,7 +35,7 @@ def get_random_input_tensors(input_shapes):
"features_buffer": np.random.randn(*input_shapes['features_buffer'])
#"initial_state": np.zeros((1, 768))
}
-if int(os.getenv("ZERO_OUT", "0")):
+if getenv("ZERO_OUT", 0):
np_inputs = {k:v*0 for k,v in np_inputs.items()}
for k,v in np_inputs.items():
@@ -105,7 +106,7 @@ def compile(dat, output_fn):
from extra.thneed import Thneed
t = Thneed(CL.CACHE, {k:inputs[k].lazydata.realized.cl for k in inputs.keys()})
CL.CACHE = None
-if int(os.getenv("OPTWG", "0")):
+if getenv("OPTWG", 0):
t.optimize_local_workgroup()
# save thneed (before run)
@@ -121,7 +122,7 @@ def compile(dat, output_fn):
np.testing.assert_allclose(thneed_out, tinygrad_out.numpy())
# float32 only (fix this)
-FLOAT16 = int(os.getenv("FLOAT16", 0))
+FLOAT16 = getenv("FLOAT16", 0)
if FLOAT16 == 0:
try:
from test.test_onnx import run_onnx_torch

View File

@@ -1,19 +1,19 @@
#!/usr/bin/env python
-import os
import unittest
import numpy as np
from tinygrad.ops import LazyOp, ReduceOps, BinaryOps, UnaryOps, MovementOps
from tinygrad.shape import ShapeTracker, View, ZeroView
from tinygrad.llops.ops_gpu import GPUBuffer, CLASTKernel
+from tinygrad.helpers import getenv
from extra.lib_test_ast import test_ast
def compile_and_test_ast(ast):
k = CLASTKernel(ast)
-if int(os.getenv("KOPT", "0")):
+if getenv("KOPT", 0):
from extra.kernel_search import apply_optimization
-apply_optimization(k, ast, 10, int(os.getenv("KCACHE", "0")))
+apply_optimization(k, ast, 10, getenv("KCACHE", 0))
k.codegen()(*k.bufs)
-if not int(os.getenv("NOTEST", "0")): test_ast(k)
+if not getenv("NOTEST", 0): test_ast(k)
class TestAST(unittest.TestCase):
def test_conv_zeroview_ast(self):

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
-import os
import unittest
import numpy as np
from tinygrad.tensor import Tensor, Device

View File

@@ -1,11 +1,11 @@
-import os
import torch
import time
import numpy as np
import unittest
from tinygrad.tensor import Tensor, Device
+from tinygrad.helpers import getenv
-FORWARD_ONLY = bool(int(os.getenv("FORWARD_ONLY", "0")))
+FORWARD_ONLY = getenv("FORWARD_ONLY", 0)
def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, grad_atol=1e-4, grad_rtol=1e-3, forward_only=False, vals=None, a=-0.5, b=3):
if tinygrad_fxn is None: tinygrad_fxn = torch_fxn
torch.manual_seed(0)

View File

@@ -10,15 +10,15 @@ from functools import partial
from tinygrad.ops import GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d
-from tinygrad.helpers import colored
+from tinygrad.helpers import colored, getenv
try:
from tinygrad.llops.ops_gpu import CL
except ImportError:
CL = None
-IN_CHANS = [int(x) for x in os.getenv("IN_CHANS", "4,16,64").split(",")]
+IN_CHANS = [int(x) for x in getenv("IN_CHANS", "4,16,64").split(",")]
-torch_device = torch.device('mps' if int(os.getenv("MPS", "0")) else ('cuda' if int(os.getenv("TORCHCUDA", "0")) else 'cpu'))
+torch_device = torch.device('mps' if getenv("MPS", 0) else ('cuda' if getenv("TORCHCUDA", 0) else 'cpu'))
def colorize_float(x):
ret = f"{x:7.2f}x"

View File

@@ -1,9 +1,9 @@
-import os
import unittest
import time
import tinygrad.nn.optim as optim
import numpy as np
from tinygrad.tensor import Device
+from tinygrad.helpers import getenv
from extra.training import train
from extra.utils import get_parameters
from models.efficientnet import EfficientNet
@@ -11,7 +11,7 @@ from models.transformer import Transformer
from models.vit import ViT
from models.resnet import ResNet18
-BS = int(os.getenv("BS", "2"))
+BS = getenv("BS", 2)
def train_one_step(model,X,Y):
params = get_parameters(model)

View File

@@ -4,8 +4,9 @@ import itertools
from collections import defaultdict
from typing import Dict, List
from tinygrad.ops import DeviceBuffer, DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, Op, OpType
+from tinygrad.helpers import getenv
-GRAPH = int(os.getenv("GRAPH", "0"))
+GRAPH = getenv("GRAPH", 0)
# **** debugging and graphing ****
@@ -16,7 +17,7 @@ if GRAPH:
def save_graph_exit():
for k,v in cnts.items():
print(k, v)
-if int(os.getenv("PRUNEGRAPH", "0")):
+if getenv("PRUNEGRAPH", 0):
dead_nodes = []
for n in G.nodes:
# prune movementops and loadops

View File

@@ -1,18 +1,19 @@
from __future__ import annotations
from typing import Optional, Tuple, Union, List, Dict
from copy import copy
-import os, sys, weakref
+import sys, weakref
from tinygrad.helpers import ConvArgs, get_available_llops, prod
from tinygrad.shape import ShapeTracker
from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, OpType, LazyOp, get_buffers, get_lazyops, DEBUG
from tinygrad.graph import log_op
+from tinygrad.helpers import getenv
# lazy can recurse a lot
sys.setrecursionlimit(10000)
-OPT = int(os.getenv("OPT", "2"))
-NOCONV = int(os.getenv("NOCONV", "0"))
-IMAGE = int(os.getenv("IMAGE", "0"))
+OPT = getenv("OPT", 2)
+NOCONV = getenv("NOCONV", 0)
+IMAGE = getenv("IMAGE", 0)
# TODO: movement ops that only change shape are really nops. treat them as such
REMOVE_MOVEMENT_NOPS, MERGE_UNARY_OPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1
@@ -125,8 +126,6 @@ def get_weakop(op:LazyOp) -> LazyOp: return LazyOp(op.op, tuple(get_weakop(x) if
def get_movementroot(root:LazyBuffer) -> LazyBuffer: return get_movementroot(root.op.src[0]) if root.realized is None and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and root.op.src[0].st.contiguous)) else root
def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot(x) if x.optype == MovementOps and x.st.contiguous else x
-LAZY = int(os.getenv("LAZY", "1"))
class LazyBuffer:
lazycache : weakref.WeakValueDictionary[Tuple[str, OpType, LazyOp], LazyBuffer] = weakref.WeakValueDictionary()
def __new__(cls, device:str, shape:Union[ShapeTracker, Tuple[int, ...]], optype:OpType, op:LazyOp):
@@ -150,7 +149,7 @@ class LazyBuffer:
# NOTE: op should be read only after construction of LazyBuffer
for x in get_buffers(op):
x.children.add(self)
-if not LAZY:
+if not getenv("LAZY", 1):
self.realize()
def __repr__(self): return f"<LB {self.shape} op:{self.op.op if self.realized is None else 'realized'}>"

View File

@@ -1,5 +1,4 @@
from __future__ import annotations
-import os
import numpy as np
from typing import List, Tuple, Optional, Dict, Union, Set
from tinygrad.helpers import prod
@@ -8,17 +7,18 @@ from tinygrad.ast import ASTKernel, Token, Types
from tinygrad.lazy import IMAGE
from tinygrad.shape import ShapeTracker
from tinygrad.shape.symbolic import ModNode # this will go away when VALIDHACKS does
+from tinygrad.helpers import getenv
-CUDA = int(os.getenv("CUDA", "0"))
+CUDA = getenv("CUDA", 0)
if not CUDA: from tinygrad.runtime.opencl import CLBuffer, CLImage, CLProgram, CL # NOTE: using CL will not work for the CUDA runtime # noqa: F401
else: from tinygrad.runtime.cuda import CLBuffer, CLImage, CLProgram # type: ignore
-VALIDHACKS = int(os.getenv("VALIDHACKS", "0")) # TODO: remove the need for this
-NATIVE_EXPLOG = int(os.getenv("NATIVE_EXPLOG", "0")) # this is needed as a switch for the tests to pass
+VALIDHACKS = getenv("VALIDHACKS", 0) # TODO: remove the need for this
+NATIVE_EXPLOG = getenv("NATIVE_EXPLOG", 0) # this is needed as a switch for the tests to pass
-KOPT = int(os.getenv("KOPT", "0"))
-PRINT_AST = os.getenv("PRINT_AST", "0")
-TEST_AST = int(os.getenv("TEST_AST", "0"))
+KOPT = getenv("KOPT", 0)
+PRINT_AST = getenv("PRINT_AST", "0")
+TEST_AST = getenv("TEST_AST", 0)
def group_float4(x):
assert all(y.typ == Types.FLOAT for y in x) and len(x)%4 == 0

View File

@@ -1,8 +1,9 @@
-import os, torch
+import torch
from tinygrad.llops.ops_cpu import CPUBuffer # type: ignore
from tinygrad.ops import ProcessingOps, GenericExecAST
+from tinygrad.helpers import getenv
-device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if int(os.getenv("MPS", "0")) else "cpu"))
+device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
class TorchBuffer(torch.Tensor, GenericExecAST):
def pad(x, padding): return torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
def strided(x, arg): return x.contiguous().as_strided([y[0] for y in arg], [y[1] for y in arg])

View File

@@ -1,13 +1,13 @@
from __future__ import annotations
-import os
import numpy as np
from enum import Enum
from typing import Union, Type, NamedTuple, Tuple, Any, List
import functools, operator
from tinygrad.helpers import prod
from tinygrad.shape import ShapeTracker
+from tinygrad.helpers import getenv
-DEBUG = int(os.getenv("DEBUG", "0"))
+DEBUG = getenv("DEBUG", 0)
# these are the llops your accelerator must implement, along with toCpu
UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU", "EXP", "LOG", "SIGN", "RECIPROCAL"])
@@ -92,4 +92,4 @@ class ExplicitExecAST(DeviceBuffer): # pylint: disable=abstract-method
# TODO: creating a new object is making a copy, breaking the thneed compiler
def contiguous(self): return self if self.st.contiguous else self.unary_op(UnaryOps.NOOP)
-#def contiguous(self): return type(self)(self.shape, hostbuf=self) if self.st.contiguous else self.unary_op(UnaryOps.NOOP)
+#def contiguous(self): return type(self)(self.shape, hostbuf=self) if self.st.contiguous else self.unary_op(UnaryOps.NOOP)

View File

@@ -1,15 +1,16 @@
-import os, functools, platform
+import functools, platform
import numpy as np
import pyopencl as cl # type: ignore
from typing import Dict, Optional, Tuple, List
from collections import defaultdict
from tinygrad.ops import DEBUG
+from tinygrad.helpers import getenv
OSX = platform.system() == "Darwin"
OSX_TIMING_RATIO = (125/3) if OSX else 1.0 # see test/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
-CLCACHE = int(os.getenv("CLCACHE", "1"))
-FLOAT16 = int(os.getenv("FLOAT16", "0"))
+CLCACHE = getenv("CLCACHE", 1)
+FLOAT16 = getenv("FLOAT16", 0)
class CL:
CACHE, kernel_count, mem_used, time_sum, ops_sum = None, -1, 0, 0.0, 0.0
@@ -21,7 +22,7 @@ class CL:
devices = sum([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()], [])
if len(devices) == 0: # settle for CPU
devices = sum([x.get_devices(device_type=cl.device_type.CPU) for x in cl.get_platforms()], [])
-CL.cl_ctx = cl.Context(devices=[devices[int(os.getenv("CL_DEVICE", "0"))]])
+CL.cl_ctx = cl.Context(devices=[devices[getenv("CL_DEVICE", 0)]])
if len(devices) > 1 or DEBUG >= 1: print(f"using {CL.cl_ctx.devices}")
CL.cl_queue = cl.CommandQueue(self.cl_ctx, properties=cl.command_queue_properties.PROFILING_ENABLE) # this is an in-order command queue

View File

@@ -1,13 +1,12 @@
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
from __future__ import annotations
-import os
import functools
from typing import Tuple, Union, List, Optional
-from tinygrad.helpers import prod
+from tinygrad.helpers import prod, getenv
from tinygrad.shape.symbolic import Variable
# TODO: fix DEBUG import
-DEBUG = int(os.getenv("DEBUG", "0"))
+DEBUG = getenv("DEBUG", 0)
@functools.lru_cache(maxsize=None)
def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tuple[int, int]]: