Refactor getenv into helpers (#508)
* Refactor getenv into helpers
* Remove unused os
* Fix default value
* Fix more defaults for CI
* Fix bracket
* Revert changes to openpilot/compile.py
* Use getenv from helpers when possible
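The helper consolidates three idioms that recur throughout this diff: int(os.getenv(...)) for numeric knobs, plain os.getenv(...) for string knobs, and os.getenv(...) is not None for boolean flags. Judging from the call sites below (integer defaults come back as ints, string defaults such as "4,16,64" come back as strings), a minimal sketch of the helper could look like this; the exact body in tinygrad/helpers.py may differ:

    import os

    def getenv(key, default=0):
      # cast the environment value to the type of the default, so that
      # getenv("BS", 8) returns an int and getenv("IN_CHANS", "4,16,64") a str
      return type(default)(os.getenv(key, default))

With such a helper, BACKWARD = int(os.getenv("BACKWARD", 0)) collapses to BACKWARD = getenv("BACKWARD", 0), and presence checks like os.getenv("GPU") is not None become truthiness checks on an integer default: GPU = getenv("GPU").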
@@ -1,8 +1,8 @@
 #!/usr/bin/env python3
-import os
 import sys
 from hexdump import hexdump
 from macholib import MachO
+from tinygrad.helpers import getenv
 def get_macho(fn):
   # mod to make the header okay
   # MH_CIGAM_64 is good
@@ -124,7 +124,7 @@ for i in range(0, len(f2), 0x300):
   c1, c2 = f1[i:i+0x300], f2[i:i+0x300]
   dbg1 = ane.debug(c1, 16)
   dbg2 = ane.debug(c2, 16)
-  if os.getenv("PRINTALL"):
+  if getenv("PRINTALL"):
     for k in dbg2:
       if k in aneregs:
         rr = aneregs[k] if k in aneregs else (-1,-1,-1)

@@ -14,7 +14,6 @@
 # 189 Mb embedded RAM, aka 9M 19-bit elements
 # 2560 MLP blocks, 2 fp24 MULACC each

-import os
 import functools
 import numpy as np
 from collections import defaultdict
@@ -183,7 +182,7 @@ binops = {BinaryOps.ADD: riski_add,
 reduceops = {ReduceOps.SUM: riski_reduce_sum,
              ReduceOps.MAX: riski_reduce_max}

-SLOW_MATMUL = os.getenv("SLOW_MATMUL", False)
+SLOW_MATMUL = getenv("SLOW_MATMUL", False)
 @count
 def riski_matmul(slow=SLOW_MATMUL):
   #print("LLL:\n",regfile[Reg.MATMUL_INPUT],"\n",regfile[Reg.MATMUL_WEIGHTS])
@@ -205,7 +204,7 @@ def riski_mov(tout, tin):
 def riski_zero(tout):
   regfile[tout][:, :] = 0

-load_log = open("/tmp/risk_load_log", "w") if os.getenv("LOAD_LOG") else None
+load_log = open("/tmp/risk_load_log", "w") if getenv("LOAD_LOG") else None

 @count
 def riski_load(target, address, stride_y=SZ, stride_x=1, len_y=SZ, len_x=SZ, zero=True, skip_first=False):

@@ -1,5 +1,4 @@
 #!/usr/bin/env python3
-import os
 import time
 from tqdm import trange
 from models.efficientnet import EfficientNet
@@ -7,18 +6,19 @@ import tinygrad.nn.optim as optim
 from tinygrad.tensor import Tensor
 from tinygrad.llops.ops_gpu import CL
 from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import getenv

 import gc
 def tensors_allocated():
   return sum([isinstance(x, Tensor) for x in gc.get_objects()])

-NUM = int(os.getenv("NUM", 2))
-BS = int(os.getenv("BS", 8))
-CNT = int(os.getenv("CNT", 10))
-BACKWARD = int(os.getenv("BACKWARD", 0))
-TRAINING = int(os.getenv("TRAINING", 1))
-ADAM = int(os.getenv("ADAM", 0))
-CLCACHE = int(os.getenv("CLCACHE", "0"))
+NUM = getenv("NUM", 2)
+BS = getenv("BS", 8)
+CNT = getenv("CNT", 10)
+BACKWARD = getenv("BACKWARD", 0)
+TRAINING = getenv("TRAINING", 1)
+ADAM = getenv("ADAM", 0)
+CLCACHE = getenv("CLCACHE", 0)

 if __name__ == "__main__":
   print(f"NUM:{NUM} BS:{BS} CNT:{CNT}")

@@ -9,6 +9,7 @@ import time
 import numpy as np
 np.set_printoptions(suppress=True)
 from tinygrad.tensor import Tensor
+from tinygrad.helpers import getenv
 from extra.utils import fetch, get_parameters
 from models.efficientnet import EfficientNet

@@ -48,7 +49,7 @@ def infer(model, img):

 if __name__ == "__main__":
   # instantiate my net
-  model = EfficientNet(int(os.getenv("NUM", "0")))
+  model = EfficientNet(getenv("NUM", 0))
   model.load_from_pretrained()

   # category labels

@@ -7,12 +7,13 @@ sys.path.append(os.getcwd())
 sys.path.append(os.path.join(os.getcwd(), 'test'))

 from tinygrad.tensor import Tensor
+from tinygrad.helpers import getenv
 from extra.utils import get_parameters
 import tinygrad.nn.optim as optim
 from datasets import fetch_mnist
 from torchvision.utils import make_grid, save_image
 import torch
-GPU = os.getenv("GPU") is not None
+GPU = getenv("GPU")
 class LinearGen:
   def __init__(self):
     lv = 128

@@ -12,10 +12,11 @@ from extra.utils import get_parameters
 from datasets import fetch_mnist
 from extra.training import train, evaluate, sparse_categorical_crossentropy
 import tinygrad.nn.optim as optim
+from tinygrad.helpers import getenv
 from extra.augment import augment_img
-GPU = os.getenv("GPU", None) is not None
-QUICK = os.getenv("QUICK", None) is not None
-DEBUG = os.getenv("DEBUG", None) is not None
+GPU = getenv("GPU")
+QUICK = getenv("QUICK")
+DEBUG = getenv("DEBUG")

 class SqueezeExciteBlock2D:
   def __init__(self, filters):

@@ -8,6 +8,7 @@ from extra.utils import get_parameters
 from tqdm import trange
 from tinygrad.nn import BatchNorm2D
 import tinygrad.nn.optim as optim
+from tinygrad.helpers import getenv
 from datasets import fetch_cifar

 class TinyConvNet:
@@ -29,24 +30,24 @@ class TinyConvNet:
     return x.dot(self.l1)

 if __name__ == "__main__":
-  IMAGENET = os.getenv("IMAGENET") is not None
+  IMAGENET = getenv("IMAGENET")
   classes = 1000 if IMAGENET else 10

-  TINY = os.getenv("TINY") is not None
-  TRANSFER = os.getenv("TRANSFER") is not None
+  TINY = getenv("TINY")
+  TRANSFER = getenv("TRANSFER")
   if TINY:
     model = TinyConvNet(classes)
   elif TRANSFER:
-    model = EfficientNet(int(os.getenv("NUM", "0")), classes, has_se=True)
+    model = EfficientNet(getenv("NUM", 0), classes, has_se=True)
     model.load_from_pretrained()
   else:
-    model = EfficientNet(int(os.getenv("NUM", "0")), classes, has_se=False)
+    model = EfficientNet(getenv("NUM", 0), classes, has_se=False)

   parameters = get_parameters(model)
   print("parameter count", len(parameters))
   optimizer = optim.Adam(parameters, lr=0.001)

-  BS, steps = int(os.getenv("BS", "64" if TINY else "16")), int(os.getenv("STEPS", "2048"))
+  BS, steps = getenv("BS", 64 if TINY else 16), getenv("STEPS", 2048)
   print("training with batch size %d for %d steps" % (BS, steps))

   if IMAGENET:

@@ -1,5 +1,4 @@
 #!/usr/bin/env python3
-import os
 import numpy as np
 import random
 from PIL import Image
@@ -9,9 +8,9 @@ from extra.utils import get_parameters
 from extra.training import train, evaluate
 from models.resnet import ResNet
 from tinygrad.nn.optim import Adam
+from tinygrad.helpers import getenv
 from datasets import fetch_mnist

-from tinygrad.nn.optim import Adam

 class ComposeTransforms:
   def __init__(self, trans):
@@ -28,9 +27,8 @@ if __name__ == "__main__":
   X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
   classes = 10

-  TRANSFER = os.getenv('TRANSFER') is not None
-  NUM = int(os.getenv('NUM', '18'))
-  model = ResNet(NUM, num_classes=classes)
+  TRANSFER = getenv('TRANSFER')
+  model = ResNet(getenv('NUM', 18), num_classes=classes)
   if TRANSFER:
     model.load_from_pretrained()

@@ -1,5 +1,4 @@
 #!/usr/bin/env python3
-import os
 import numpy as np
 import random

@@ -11,11 +11,11 @@ with tf.io.gfile.GFile(fn, "rb") as f:


 from tinygrad.tensor import Tensor
+from tinygrad.helpers import getenv
 from models.vit import ViT
-import os

 Tensor.training = False
-if int(os.getenv("LARGE", "0")) == 1:
+if getenv("LARGE", 0) == 1:
   m = ViT(embed_dim=768, num_heads=12)
 else:
   # tiny

@@ -1,8 +1,6 @@
 # https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg
 # running

-import os
-GPU = os.getenv("GPU", None) is not None
 import sys
 import io
 import time
@@ -12,6 +10,8 @@ from tinygrad.tensor import Tensor
 from extra.utils import fetch, get_parameters
 from examples.yolo.yolo_nn import Upsample, EmptyLayer, DetectionLayer, LeakyReLU, MaxPool2d
 from tinygrad.nn import BatchNorm2D, Conv2d
+from tinygrad.helpers import getenv
+GPU = getenv("GPU")

 import cv2
 from PIL import Image

@@ -1,5 +1,5 @@
 #!/usr/bin/env python
-import os, random, traceback
+import random, traceback
 import time
 import itertools
 from enum import Enum
@@ -9,6 +9,7 @@ from tinygrad.shape import ShapeTracker, View, ZeroView
 from tinygrad.llops.ops_gpu import GPUBuffer, CLASTKernel, CL
 from tinygrad.runtime.opencl import OSX_TIMING_RATIO
 from tinygrad.ops import DEBUG
+from tinygrad.helpers import getenv
 from extra.lib_test_ast import test_ast

 import pickle, dbm
@@ -163,7 +164,7 @@ def test_correctness(ast):
   test_ast(k)

 if __name__ == "__main__":
-  if int(os.getenv("OP", "0")) == 1:
+  if getenv("OP", 0) == 1:
     buf0 = GPUBuffer(shape=ShapeTracker(shape=(1, 64, 128, 8, 4, 3, 3, 3, 4), views=[View((1, 130, 258, 1, 12), (393216, 3072, 12, 12, 1), -3084), ZeroView((1, 128, 256, 1, 12), ((0, 1), (-1, 129), (-1, 257), (0, 1), (0, 12))), View((1, 64, 128, 8, 4, 3, 3, 3, 4), (0, 6192, 24, 0, 0, 3096, 12, 4, 1), 0)]), hostbuf=GPUBuffer(shape=(128, 768, 4), force_create=True))
     buf1 = GPUBuffer(shape=ShapeTracker(shape=(1, 64, 128, 8, 4, 3, 3, 3, 4), views=[View((1, 64, 128, 8, 4, 3, 3, 3, 4), (0, 0, 0, 432, 4, 144, 16, 48, 1), 0)]), hostbuf=GPUBuffer(shape=(8, 108, 4), force_create=True))
     op0 = LazyOp(BinaryOps.MUL, (buf0,buf1,), None)
@@ -179,7 +180,7 @@ if __name__ == "__main__":
     op7 = LazyOp(BinaryOps.MUL, (buf3,op6,), None)
     op8 = LazyOp(BinaryOps.SUB, (op3,op7,), None)
     ast = LazyOp(MovementOps.RESHAPE, (op8,), (64, 1024, 4))
-  elif int(os.getenv("OP", "0")) == 2:
+  elif getenv("OP", 0) == 2:
     buf0 = GPUBuffer(shape=ShapeTracker(shape=(1, 64, 128, 8, 4, 1, 1, 3, 3), views=[View((1, 66, 130, 32, 1), (262144, 4096, 32, 1, 1), -4128), ZeroView((1, 64, 128, 32, 1), ((0, 1), (-1, 65), (-1, 129), (0, 32), (0, 1))), View((1, 64, 128, 8, 4, 1, 1, 3, 3), (266240, 4160, 32, 4, 1, 12480, 12480, 4160, 32), 0)]), hostbuf=GPUBuffer(shape=(64, 1024, 4), force_create=True))
     buf1 = GPUBuffer(shape=ShapeTracker(shape=(1, 64, 128, 8, 4, 1, 1, 3, 3), views=[View((1, 64, 128, 8, 4, 1, 1, 3, 3), (0, 0, 0, 36, 1, 0, 0, 12, 4), 0)]), hostbuf=GPUBuffer(shape=(8, 9, 4), force_create=True))
     op0 = LazyOp(BinaryOps.MUL, (buf0,buf1,), None)
@@ -195,7 +196,7 @@ if __name__ == "__main__":
     op7 = LazyOp(BinaryOps.MUL, (buf3,op6,), None)
     op8 = LazyOp(BinaryOps.SUB, (op3,op7,), None)
     ast = LazyOp(MovementOps.RESHAPE, (op8,), (64, 1024, 4))
-  elif int(os.getenv("OP", "0")) == 3:
+  elif getenv("OP", 0) == 3:
     buf0 = GPUBuffer(shape=ShapeTracker(shape=(1, 64, 128, 4, 4, 1, 1, 8, 4), views=[View((1, 64, 128, 4, 4, 1, 1, 8, 4), (0, 4096, 32, 0, 0, 0, 0, 4, 1), 0)]), hostbuf=GPUBuffer(shape=(64, 1024, 4), force_create=True))
     buf1 = GPUBuffer(shape=ShapeTracker(shape=(1, 64, 128, 4, 4, 1, 1, 8, 4), views=[View((1, 64, 128, 4, 4, 1, 1, 8, 4), (0, 0, 0, 128, 4, 0, 0, 16, 1), 0)]), hostbuf=GPUBuffer(shape=(4, 32, 4), force_create=True))
     op0 = LazyOp(BinaryOps.MUL, (buf0,buf1,), None)
@@ -203,38 +204,38 @@ if __name__ == "__main__":
     buf2 = GPUBuffer(shape=ShapeTracker(shape=(1, 64, 128, 4, 4, 1, 1, 1, 1), views=[View((1, 64, 128, 4, 4, 1, 1, 1, 1), (0, 0, 0, 4, 1, 1, 1, 1, 1), 0)]), hostbuf=GPUBuffer(shape=(16,), force_create=True))
     op2 = LazyOp(BinaryOps.ADD, (op1,buf2,), None)
     ast = LazyOp(MovementOps.RESHAPE, (op2,), (64, 512, 4))
-  elif int(os.getenv("REDUCE", "0")):
+  elif getenv("REDUCE", 0):
     buf0 = GPUBuffer(shape=ShapeTracker(shape=(32, 8, 112, 112), views=[View((32, 8, 112, 112), (12544, 401408, 112, 1), 0)]), hostbuf=GPUBuffer(shape=(8, 32, 112, 112), force_create=True))
     op0 = LazyOp(ReduceOps.SUM, (buf0,), (32, 1, 1, 1))
     buf1 = GPUBuffer(shape=ShapeTracker(shape=(32, 1, 1, 1), views=[View((32, 1, 1, 1), (0, 0, 0, 0), 0)]), hostbuf=GPUBuffer(shape=(1,), backing=np.array([9.964923e-06], dtype=np.float32)))
     op1 = LazyOp(BinaryOps.MUL, (op0,buf1,), None)
     ast = LazyOp(MovementOps.RESHAPE, (op1,), (1, 32, 1, 1))
-  elif int(os.getenv("CONVW", "0")):
+  elif getenv("CONVW", 0):
     buf0 = GPUBuffer(shape=ShapeTracker(shape=(64, 1, 128, 3, 3, 512, 32, 32), views=[View((64, 512, 34, 34), (1024, 65536, 32, 1), -33), ZeroView((64, 512, 32, 32), ((0, 64), (0, 512), (-1, 33), (-1, 33))), View((64, 1, 128, 3, 3, 512, 32, 32), (591872, 591872, 0, 34, 1, 1156, 34, 1), 0)]), hostbuf=GPUBuffer(shape=(512, 64, 32, 32), force_create=True))
     buf1 = GPUBuffer(shape=ShapeTracker(shape=(64, 1, 128, 3, 3, 512, 32, 32), views=[View((64, 1, 128, 3, 3, 512, 32, 32), (0, 0, 1024, 0, 0, 131072, 32, 1), 0)]), hostbuf=GPUBuffer(shape=(512, 128, 32, 32), force_create=True))
     op0 = LazyOp(BinaryOps.MUL, (buf0,buf1,), None)
     op1 = LazyOp(ReduceOps.SUM, (op0,), (64, 1, 128, 3, 3, 1, 1, 1))
     ast = LazyOp(MovementOps.RESHAPE, (op1,), (64, 128, 3, 3))
-  elif int(os.getenv("BC", "0")):
+  elif getenv("BC", 0):
     # big conv
     buf0 = GPUBuffer(shape=ShapeTracker(shape=(8, 1, 32, 112, 112, 3, 3, 3), views=[View((8, 3, 225, 225), (150528, 50176, 224, 1), 0), ZeroView((8, 3, 224, 224), ((0, 8), (0, 3), (0, 225), (0, 225))), View((8, 1, 32, 112, 112, 3, 3, 3), (151875, 151875, 0, 450, 2, 50625, 225, 1), 0)]), hostbuf=GPUBuffer(shape=(8, 3, 224, 224), force_create=True))
     buf1 = GPUBuffer(shape=ShapeTracker(shape=(8, 1, 32, 112, 112, 3, 3, 3), views=[View((8, 1, 32, 112, 112, 3, 3, 3), (0, 0, 27, 0, 0, 9, 3, 1), 0)]), hostbuf=GPUBuffer(shape=(32, 3, 3, 3), force_create=True))
     op0 = LazyOp(BinaryOps.MUL, (buf0,buf1,), None)
     op1 = LazyOp(ReduceOps.SUM, (op0,), (8, 1, 32, 112, 112, 1, 1, 1))
     ast = LazyOp(MovementOps.RESHAPE, (op1,), (8, 32, 112, 112))
-  elif int(os.getenv("GEMM", "0")):
+  elif getenv("GEMM", 0):
     buf0 = GPUBuffer(shape=ShapeTracker(shape=(1, 1, 512, 512, 1, 1, 1, 512), views=[View((1, 512, 512, 1), (0, 1, 512, 0), 0), View((1, 1, 512, 512, 1, 1, 1, 512), (0, 0, 0, 1, 0, 0, 0, 512), 0)]), hostbuf=GPUBuffer(shape=(512, 512), force_create=True))
     buf1 = GPUBuffer(shape=ShapeTracker(shape=(1, 1, 512, 512, 1, 1, 1, 512), views=[View((1, 1, 512, 512, 1, 1, 1, 512), (0, 0, 1, 0, 0, 0, 0, 512), 0)]), hostbuf=GPUBuffer(shape=(512, 512), force_create=True))
     op0 = LazyOp(BinaryOps.MUL, (buf0,buf1,), None)
     op1 = LazyOp(ReduceOps.SUM, (op0,), (1, 1, 512, 512, 1, 1, 1, 1))
     ast = LazyOp(MovementOps.RESHAPE, (op1,), (512, 512))
-  elif int(os.getenv("FASTCONV", "0")):
+  elif getenv("FASTCONV", 0):
     buf0 = GPUBuffer(shape=ShapeTracker(shape=(32, 1, 32, 32, 32, 64, 3, 3), views=[View((32, 1, 32, 32, 32, 64, 3, 3), (73984, 73984, 0, 34, 1, 1156, 34, 1), 0)]), hostbuf=GPUBuffer(shape=(32, 64, 34, 34), force_create=True))
     buf1 = GPUBuffer(shape=ShapeTracker(shape=(32, 1, 32, 32, 32, 64, 3, 3), views=[View((32, 1, 32, 32, 32, 64, 3, 3), (0, 0, 576, 0, 0, 9, 3, 1), 0)]), hostbuf=GPUBuffer(shape=(32, 64, 3, 3), force_create=True))
     op0 = LazyOp(BinaryOps.MUL, (buf0,buf1,), None)
     op1 = LazyOp(ReduceOps.SUM, (op0,), (32, 1, 32, 32, 32, 1, 1, 1))
     ast = LazyOp(MovementOps.RESHAPE, (op1,), (32, 32, 32, 32))
-  elif int(os.getenv("BROKEN", "0")):
+  elif getenv("BROKEN", 0):
     buf0 = GPUBuffer(shape=ShapeTracker(shape=(64, 1, 1, 1), views=[View((64, 1, 1, 1), (1, 0, 0, 0), 0)]), hostbuf=GPUBuffer(shape=(64,), force_create=True))
     buf1 = GPUBuffer(shape=ShapeTracker(shape=(64, 5, 32, 32), views=[View((64, 5, 32, 32), (5120, 1024, 32, 1), 0)]), hostbuf=GPUBuffer(shape=(64, 5, 32, 32), force_create=True))
     op0 = LazyOp(ReduceOps.SUM, (buf1,), (64, 1, 1, 1))

@@ -1,11 +1,11 @@
-import os
 import numpy as np
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import prod
 from tinygrad.ops import DEBUG
+from tinygrad.helpers import getenv
 from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE

-ONNXLIMIT = int(os.getenv("ONNXLIMIT", "-1"))
+ONNXLIMIT = getenv("ONNXLIMIT", -1)

 def get_run_onnx(onnx_model):
   def shape_to_tuple(s): return tuple(x.dim_value for x in s.dim)

@@ -1,18 +1,17 @@
 # this can be constructed from a cl_cache or loaded from a thneed file
-import os
 import time
 import struct
 import json
 import traceback
 import numpy as np
 from tinygrad.llops.ops_gpu import CL, CLProgram
-from tinygrad.helpers import prod
+from tinygrad.helpers import prod, getenv
 from collections import defaultdict
 import pyopencl as cl
 from tinygrad.runtime.opencl import OSX_TIMING_RATIO

-DEBUGCL = int(os.getenv("DEBUGCL", 0))
-FLOAT16 = int(os.getenv("FLOAT16", 0))
+DEBUGCL = getenv("DEBUGCL", 0)
+FLOAT16 = getenv("FLOAT16", 0)

 class Thneed:
   def __init__(self, cl_cache=[], inputs={}):
@@ -284,7 +283,7 @@ class Thneed:
     for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)):
       runtime = (e.profile.end - e.profile.start) * OSX_TIMING_RATIO
       print(f"{i:3d} time {total_runtime/1e6:5.2f} ms running {prg.name:20s} with {str(args[0]):15s} {str(args[1]):15s} count {len(args)-2:2d} runtime {runtime/1e3:7.2f} us {(prg.op_estimate)/runtime:9.2f} GFLOPS {prg.options} -> {args[2].shape if hasattr(args[2], 'shape') else args[2].size}")
-      if (DEBUGCL >= 2 and int(os.getenv("PRINT_KERNEL", "-1")) == i) or DEBUGCL >= 3:
+      if (DEBUGCL >= 2 and getenv("PRINT_KERNEL", -1) == i) or DEBUGCL >= 3:
         print(prg.prg)
       total_runtime += runtime
     print(f"total runtime: {total_runtime/1e6:.2f} ms wall time: {et*1000.0:.2f} ms")

@@ -1,8 +1,8 @@
-import os
 import numpy as np
 from tqdm import trange
 from extra.utils import get_parameters
 from tinygrad.tensor import Tensor, Device
+from tinygrad.helpers import getenv

 def sparse_categorical_crossentropy(out, Y):
   num_classes = out.shape[-1]
@@ -18,7 +18,7 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categorical_crossentropy,
           transform=lambda x: x, target_transform=lambda x: x, noloss=False):
   Tensor.training = True
   losses, accuracies = [], []
-  for i in (t := trange(steps, disable=os.getenv('CI') is not None)):
+  for i in (t := trange(steps, disable=getenv('CI', False))):
     samp = np.random.randint(0, X_train.shape[0], size=(BS))
     x = Tensor(transform(X_train[samp]), requires_grad=False)
     y = target_transform(Y_train[samp])
@@ -48,7 +48,7 @@ def evaluate(model, X_test, Y_test, num_classes=None, BS=128, return_predict=False):
   Tensor.training = False
   def numpy_eval(Y_test, num_classes):
     Y_test_preds_out = np.zeros(list(Y_test.shape)+[num_classes])
-    for i in trange((len(Y_test)-1)//BS+1, disable=os.getenv('CI') is not None):
+    for i in trange((len(Y_test)-1)//BS+1, disable=getenv('CI', False)):
       x = Tensor(transform(X_test[i*BS:(i+1)*BS]))
       out = model.forward(x) if hasattr(model, 'forward') else model(x)
       Y_test_preds_out[i*BS:(i+1)*BS] = out.cpu().data

@@ -2,7 +2,7 @@ from tinygrad.tensor import Tensor
 import pickle
 from tqdm import tqdm
 import numpy as np
-from tinygrad.helpers import prod
+from tinygrad.helpers import prod, getenv

 def fetch(url):
   if url.startswith("/"):
@@ -11,7 +11,7 @@ def fetch(url):
     return dat
   import requests, os, hashlib, tempfile
   fp = os.path.join(tempfile.gettempdir(), hashlib.md5(url.encode('utf-8')).hexdigest())
-  if os.path.isfile(fp) and os.stat(fp).st_size > 0 and os.getenv("NOCACHE", None) is None:
+  if os.path.isfile(fp) and os.stat(fp).st_size > 0 and not getenv("NOCACHE"):
     with open(fp, "rb") as f:
       dat = f.read()
   else:

@@ -7,8 +7,9 @@ if os.getenv("OPT", None) is None:
|
||||
if os.getenv("GPU", None) is None:
|
||||
os.environ['GPU'] = '1'
|
||||
|
||||
ALLOWED_KERNEL_COUNT = int(os.getenv("ALLOWED_KERNEL_COUNT", 0))
|
||||
DEBUGCL = int(os.getenv("DEBUGCL", 0))
|
||||
from tinygrad.helpers import getenv
|
||||
ALLOWED_KERNEL_COUNT = getenv("ALLOWED_KERNEL_COUNT", 0)
|
||||
DEBUGCL = getenv("DEBUGCL", 0)
|
||||
|
||||
import onnx
|
||||
import numpy as np
|
||||
@@ -34,7 +35,7 @@ def get_random_input_tensors(input_shapes):
|
||||
"features_buffer": np.random.randn(*input_shapes['features_buffer'])
|
||||
#"initial_state": np.zeros((1, 768))
|
||||
}
|
||||
if int(os.getenv("ZERO_OUT", "0")):
|
||||
if getenv("ZERO_OUT", 0):
|
||||
np_inputs = {k:v*0 for k,v in np_inputs.items()}
|
||||
|
||||
for k,v in np_inputs.items():
|
||||
@@ -105,7 +106,7 @@ def compile(dat, output_fn):
|
||||
from extra.thneed import Thneed
|
||||
t = Thneed(CL.CACHE, {k:inputs[k].lazydata.realized.cl for k in inputs.keys()})
|
||||
CL.CACHE = None
|
||||
if int(os.getenv("OPTWG", "0")):
|
||||
if getenv("OPTWG", 0):
|
||||
t.optimize_local_workgroup()
|
||||
|
||||
# save thneed (before run)
|
||||
@@ -121,7 +122,7 @@ def compile(dat, output_fn):
|
||||
np.testing.assert_allclose(thneed_out, tinygrad_out.numpy())
|
||||
|
||||
# float32 only (fix this)
|
||||
FLOAT16 = int(os.getenv("FLOAT16", 0))
|
||||
FLOAT16 = getenv("FLOAT16", 0)
|
||||
if FLOAT16 == 0:
|
||||
try:
|
||||
from test.test_onnx import run_onnx_torch
|
||||
|
||||
@@ -1,19 +1,19 @@
 #!/usr/bin/env python
-import os
 import unittest
 import numpy as np
 from tinygrad.ops import LazyOp, ReduceOps, BinaryOps, UnaryOps, MovementOps
 from tinygrad.shape import ShapeTracker, View, ZeroView
 from tinygrad.llops.ops_gpu import GPUBuffer, CLASTKernel
+from tinygrad.helpers import getenv
 from extra.lib_test_ast import test_ast

 def compile_and_test_ast(ast):
   k = CLASTKernel(ast)
-  if int(os.getenv("KOPT", "0")):
+  if getenv("KOPT", 0):
     from extra.kernel_search import apply_optimization
-    apply_optimization(k, ast, 10, int(os.getenv("KCACHE", "0")))
+    apply_optimization(k, ast, 10, getenv("KCACHE", 0))
   k.codegen()(*k.bufs)
-  if not int(os.getenv("NOTEST", "0")): test_ast(k)
+  if not getenv("NOTEST", 0): test_ast(k)

 class TestAST(unittest.TestCase):
   def test_conv_zeroview_ast(self):

@@ -1,5 +1,4 @@
 #!/usr/bin/env python
-import os
 import unittest
 import numpy as np
 from tinygrad.tensor import Tensor, Device

@@ -1,11 +1,11 @@
-import os
 import torch
 import time
 import numpy as np
 import unittest
 from tinygrad.tensor import Tensor, Device
+from tinygrad.helpers import getenv

-FORWARD_ONLY = bool(int(os.getenv("FORWARD_ONLY", "0")))
+FORWARD_ONLY = getenv("FORWARD_ONLY", 0)
 def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, grad_atol=1e-4, grad_rtol=1e-3, forward_only=False, vals=None, a=-0.5, b=3):
   if tinygrad_fxn is None: tinygrad_fxn = torch_fxn
   torch.manual_seed(0)

@@ -10,15 +10,15 @@ from functools import partial
 from tinygrad.ops import GlobalCounters
 from tinygrad.tensor import Tensor
 from tinygrad.nn import Conv2d
-from tinygrad.helpers import colored
+from tinygrad.helpers import colored, getenv
 try:
   from tinygrad.llops.ops_gpu import CL
 except ImportError:
   CL = None

-IN_CHANS = [int(x) for x in os.getenv("IN_CHANS", "4,16,64").split(",")]
+IN_CHANS = [int(x) for x in getenv("IN_CHANS", "4,16,64").split(",")]

-torch_device = torch.device('mps' if int(os.getenv("MPS", "0")) else ('cuda' if int(os.getenv("TORCHCUDA", "0")) else 'cpu'))
+torch_device = torch.device('mps' if getenv("MPS", 0) else ('cuda' if getenv("TORCHCUDA", 0) else 'cpu'))

 def colorize_float(x):
   ret = f"{x:7.2f}x"

@@ -1,9 +1,9 @@
-import os
 import unittest
 import time
 import tinygrad.nn.optim as optim
 import numpy as np
 from tinygrad.tensor import Device
+from tinygrad.helpers import getenv
 from extra.training import train
 from extra.utils import get_parameters
 from models.efficientnet import EfficientNet
@@ -11,7 +11,7 @@ from models.transformer import Transformer
 from models.vit import ViT
 from models.resnet import ResNet18

-BS = int(os.getenv("BS", "2"))
+BS = getenv("BS", 2)

 def train_one_step(model,X,Y):
   params = get_parameters(model)

@@ -4,8 +4,9 @@ import itertools
 from collections import defaultdict
 from typing import Dict, List
 from tinygrad.ops import DeviceBuffer, DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, Op, OpType
+from tinygrad.helpers import getenv

-GRAPH = int(os.getenv("GRAPH", "0"))
+GRAPH = getenv("GRAPH", 0)

 # **** debugging and graphing ****

@@ -16,7 +17,7 @@ if GRAPH:
   def save_graph_exit():
     for k,v in cnts.items():
       print(k, v)
-    if int(os.getenv("PRUNEGRAPH", "0")):
+    if getenv("PRUNEGRAPH", 0):
       dead_nodes = []
       for n in G.nodes:
         # prune movementops and loadops

@@ -1,18 +1,19 @@
 from __future__ import annotations
 from typing import Optional, Tuple, Union, List, Dict
 from copy import copy
-import os, sys, weakref
+import sys, weakref
 from tinygrad.helpers import ConvArgs, get_available_llops, prod
 from tinygrad.shape import ShapeTracker
 from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, OpType, LazyOp, get_buffers, get_lazyops, DEBUG
 from tinygrad.graph import log_op
+from tinygrad.helpers import getenv

 # lazy can recurse a lot
 sys.setrecursionlimit(10000)

-OPT = int(os.getenv("OPT", "2"))
-NOCONV = int(os.getenv("NOCONV", "0"))
-IMAGE = int(os.getenv("IMAGE", "0"))
+OPT = getenv("OPT", 2)
+NOCONV = getenv("NOCONV", 0)
+IMAGE = getenv("IMAGE", 0)

 # TODO: movement ops that only change shape are really nops. treat them as such
 REMOVE_MOVEMENT_NOPS, MERGE_UNARY_OPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1
@@ -125,8 +126,6 @@ def get_weakop(op:LazyOp) -> LazyOp: return LazyOp(op.op, tuple(get_weakop(x) if
 def get_movementroot(root:LazyBuffer) -> LazyBuffer: return get_movementroot(root.op.src[0]) if root.realized is None and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and root.op.src[0].st.contiguous)) else root
 def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot(x) if x.optype == MovementOps and x.st.contiguous else x

-LAZY = int(os.getenv("LAZY", "1"))
-
 class LazyBuffer:
   lazycache : weakref.WeakValueDictionary[Tuple[str, OpType, LazyOp], LazyBuffer] = weakref.WeakValueDictionary()
   def __new__(cls, device:str, shape:Union[ShapeTracker, Tuple[int, ...]], optype:OpType, op:LazyOp):
@@ -150,7 +149,7 @@ class LazyBuffer:
     # NOTE: op should be read only after construction of LazyBuffer
     for x in get_buffers(op):
       x.children.add(self)
-    if not LAZY:
+    if not getenv("LAZY", 1):
       self.realize()

   def __repr__(self): return f"<LB {self.shape} op:{self.op.op if self.realized is None else 'realized'}>"

@@ -1,5 +1,4 @@
 from __future__ import annotations
-import os
 import numpy as np
 from typing import List, Tuple, Optional, Dict, Union, Set
 from tinygrad.helpers import prod
@@ -8,17 +7,18 @@ from tinygrad.ast import ASTKernel, Token, Types
 from tinygrad.lazy import IMAGE
 from tinygrad.shape import ShapeTracker
 from tinygrad.shape.symbolic import ModNode # this will go away when VALIDHACKS does
+from tinygrad.helpers import getenv

-CUDA = int(os.getenv("CUDA", "0"))
+CUDA = getenv("CUDA", 0)
 if not CUDA: from tinygrad.runtime.opencl import CLBuffer, CLImage, CLProgram, CL # NOTE: using CL will not work for the CUDA runtime # noqa: F401
 else: from tinygrad.runtime.cuda import CLBuffer, CLImage, CLProgram # type: ignore

-VALIDHACKS = int(os.getenv("VALIDHACKS", "0")) # TODO: remove the need for this
-NATIVE_EXPLOG = int(os.getenv("NATIVE_EXPLOG", "0")) # this is needed as a switch for the tests to pass
+VALIDHACKS = getenv("VALIDHACKS", 0) # TODO: remove the need for this
+NATIVE_EXPLOG = getenv("NATIVE_EXPLOG", 0) # this is needed as a switch for the tests to pass

-KOPT = int(os.getenv("KOPT", "0"))
-PRINT_AST = os.getenv("PRINT_AST", "0")
-TEST_AST = int(os.getenv("TEST_AST", "0"))
+KOPT = getenv("KOPT", 0)
+PRINT_AST = getenv("PRINT_AST", "0")
+TEST_AST = getenv("TEST_AST", 0)

 def group_float4(x):
   assert all(y.typ == Types.FLOAT for y in x) and len(x)%4 == 0

@@ -1,8 +1,9 @@
-import os, torch
+import torch
 from tinygrad.llops.ops_cpu import CPUBuffer # type: ignore
 from tinygrad.ops import ProcessingOps, GenericExecAST
+from tinygrad.helpers import getenv

-device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if int(os.getenv("MPS", "0")) else "cpu"))
+device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
 class TorchBuffer(torch.Tensor, GenericExecAST):
   def pad(x, padding): return torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
   def strided(x, arg): return x.contiguous().as_strided([y[0] for y in arg], [y[1] for y in arg])

@@ -1,13 +1,13 @@
 from __future__ import annotations
-import os
 import numpy as np
 from enum import Enum
 from typing import Union, Type, NamedTuple, Tuple, Any, List
 import functools, operator
 from tinygrad.helpers import prod
 from tinygrad.shape import ShapeTracker
+from tinygrad.helpers import getenv

-DEBUG = int(os.getenv("DEBUG", "0"))
+DEBUG = getenv("DEBUG", 0)

 # these are the llops your accelerator must implement, along with toCpu
 UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU", "EXP", "LOG", "SIGN", "RECIPROCAL"])
@@ -92,4 +92,4 @@ class ExplicitExecAST(DeviceBuffer): # pylint: disable=abstract-method

   # TODO: creating a new object is making a copy, breaking the thneed compiler
   def contiguous(self): return self if self.st.contiguous else self.unary_op(UnaryOps.NOOP)
-  #def contiguous(self): return type(self)(self.shape, hostbuf=self) if self.st.contiguous else self.unary_op(UnaryOps.NOOP)
+  #def contiguous(self): return type(self)(self.shape, hostbuf=self) if self.st.contiguous else self.unary_op(UnaryOps.NOOP)

@@ -1,15 +1,16 @@
-import os, functools, platform
+import functools, platform
 import numpy as np
 import pyopencl as cl # type: ignore
 from typing import Dict, Optional, Tuple, List
 from collections import defaultdict
 from tinygrad.ops import DEBUG
+from tinygrad.helpers import getenv

 OSX = platform.system() == "Darwin"
 OSX_TIMING_RATIO = (125/3) if OSX else 1.0 # see test/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something

-CLCACHE = int(os.getenv("CLCACHE", "1"))
-FLOAT16 = int(os.getenv("FLOAT16", "0"))
+CLCACHE = getenv("CLCACHE", 1)
+FLOAT16 = getenv("FLOAT16", 0)

 class CL:
   CACHE, kernel_count, mem_used, time_sum, ops_sum = None, -1, 0, 0.0, 0.0
@@ -21,7 +22,7 @@ class CL:
     devices = sum([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()], [])
     if len(devices) == 0: # settle for CPU
       devices = sum([x.get_devices(device_type=cl.device_type.CPU) for x in cl.get_platforms()], [])
-    CL.cl_ctx = cl.Context(devices=[devices[int(os.getenv("CL_DEVICE", "0"))]])
+    CL.cl_ctx = cl.Context(devices=[devices[getenv("CL_DEVICE", 0)]])
     if len(devices) > 1 or DEBUG >= 1: print(f"using {CL.cl_ctx.devices}")
     CL.cl_queue = cl.CommandQueue(self.cl_ctx, properties=cl.command_queue_properties.PROFILING_ENABLE) # this is an in-order command queue

@@ -1,13 +1,12 @@
 # ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
 from __future__ import annotations
-import os
 import functools
 from typing import Tuple, Union, List, Optional
-from tinygrad.helpers import prod
+from tinygrad.helpers import prod, getenv
 from tinygrad.shape.symbolic import Variable

 # TODO: fix DEBUG import
-DEBUG = int(os.getenv("DEBUG", "0"))
+DEBUG = getenv("DEBUG", 0)

 @functools.lru_cache(maxsize=None)
 def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tuple[int, int]]:
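
A side note on the flag call sites above (an observation about the diff, not part of the commit message): the old presence checks treated any set value as true, while the typed helper is truthy only for a non-zero value, so GPU=0 now disables the flag. A small illustration, assuming the type(default) cast sketched at the top:

    import os

    os.environ["GPU"] = "0"
    print(os.getenv("GPU") is not None)   # True: the old check only asked whether the variable was set
    print(bool(int(os.getenv("GPU", 0)))) # False: the new check uses the integer value itself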