Windows & Ubuntu CLANG CI support (#1011)

* matrix strategy

* push env to GITHUB_ENV

* use printf instead of echo

* use temp helper function for cross os paths

* use path join

* switched to using temp helper function

* skip test on windows due to memory limit

* small fix

* removed semi

* touchups

* clean up

* separate tests

* test changes to test_utils on windows

* small refactor

* more cleanups

* undo helpers change

* only skip if in CI and WINDOWS
This commit is contained in:
Diogo
2023-06-19 12:33:24 -04:00
committed by GitHub
parent 0d4c4f4e9e
commit 57d3aa76a5
6 changed files with 82 additions and 58 deletions

View File

@@ -95,6 +95,27 @@ jobs:
run: pip install -e '.[llvm,testing]' --extra-index-url https://download.pytorch.org/whl/cpu
- name: Run Pytest
run: ENABLE_METHOD_CACHE=1 LLVM=1 python -m pytest -s -v -n=auto test/
# CLANG test job: runs once for every OS listed in the matrix below.
testclang:
strategy:
matrix:
os: [ubuntu-latest, windows-latest]
runs-on: ${{ matrix.os }}
name: CLANG Tests ${{ matrix.os }} (w method cache)
steps:
- name: Checkout Code
uses: actions/checkout@v3
- name: Set up Python 3.8
uses: actions/setup-python@v4
with:
python-version: 3.8
- name: Install Dependencies
run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
# Append CI, CLANG and ENABLE_METHOD_CACHE to $GITHUB_ENV; the Run Pytest step
# below sets no variables inline, so it relies on these being exported here.
- name: Set env
run: printf "CI=1\nCLANG=1\nENABLE_METHOD_CACHE=1" >> $GITHUB_ENV
- name: Run Pytest
run: python -m pytest -s -v -n=auto test/
testtorch:
name: Torch Tests

View File

@@ -1,7 +1,7 @@
import pickle
import numpy as np
from tqdm import tqdm
import tempfile, platform
import tempfile, platform, os
from collections import defaultdict
from tinygrad.helpers import prod, getenv, DEBUG, dtypes
from tinygrad.ops import GlobalCounters
@@ -9,13 +9,16 @@ from tinygrad.tensor import Tensor
from tinygrad.lazy import Device
from tinygrad.shape.shapetracker import strides_for_shape
OSX = platform.system() == "Darwin"
WINDOWS = platform.system() == "Windows"
def temp(x:str) -> str:
  """Return the path of file name `x` placed inside the system temp directory."""
  tmp_root = tempfile.gettempdir()
  return os.path.join(tmp_root, x)
def fetch(url):
if url.startswith("/"):
with open(url, "rb") as f:
return f.read()
import os, hashlib, tempfile
fp = os.path.join(tempfile.gettempdir(), hashlib.md5(url.encode('utf-8')).hexdigest())
import hashlib
fp = temp(hashlib.md5(url.encode('utf-8')).hexdigest())
download_file(url, fp, skip_if_exists=not getenv("NOCACHE"))
with open(fp, "rb") as f:
return f.read()
@@ -24,8 +27,8 @@ def fetch_as_file(url):
if url.startswith("/"):
with open(url, "rb") as f:
return f.read()
import os, hashlib, tempfile
fp = os.path.join(tempfile.gettempdir(), hashlib.md5(url.encode('utf-8')).hexdigest())
import hashlib
fp = temp(hashlib.md5(url.encode('utf-8')).hexdigest())
download_file(url, fp, skip_if_exists=not getenv("NOCACHE"))
return fp

View File

@@ -1,8 +1,9 @@
#!/usr/bin/env python
import io
import unittest
from tinygrad.helpers import getenv
from extra.utils import fetch
import io, unittest
import torch
import numpy as np
from tinygrad.helpers import getenv
from extra.utils import fetch, temp
from tinygrad.state import torch_load
from PIL import Image
@@ -22,10 +23,9 @@ class TestFetch(unittest.TestCase):
assert pimg.size == (705, 1024)
class TestUtils(unittest.TestCase):
def test_fake_torch_load_zipped(self):
import torch
import numpy as np
import tempfile
def test_fake_torch_load_zipped(self): self._test_fake_torch_load_zipped()
def test_fake_torch_load_zipped_float16(self): self._test_fake_torch_load_zipped(isfloat16=True)
def _test_fake_torch_load_zipped(self, isfloat16=False):
class LayerWithOffset(torch.nn.Module):
def __init__(self):
super(LayerWithOffset, self).__init__()
@@ -37,25 +37,22 @@ class TestUtils(unittest.TestCase):
d.as_strided([2, 2], [1, 2], storage_offset=4)
)
for isfloat16 in [True, False]:
model = torch.nn.Sequential(
torch.nn.Linear(4, 8),
torch.nn.Linear(8, 3),
LayerWithOffset()
)
if isfloat16: model = model.half()
model = torch.nn.Sequential(
torch.nn.Linear(4, 8),
torch.nn.Linear(8, 3),
LayerWithOffset()
)
if isfloat16: model = model.half()
with tempfile.TemporaryDirectory() as tmpdirname:
path = tmpdirname + '/testloadmodel.pth'
torch.save(model.state_dict(), path)
model2 = torch_load(path)
for name, a in model.state_dict().items():
b = model2[name]
a, b = a.numpy(), b.numpy()
assert a.shape == b.shape
assert a.dtype == b.dtype
assert np.array_equal(a, b)
path = temp(f"test_load_{isfloat16}.pt")
torch.save(model.state_dict(), path)
model2 = torch_load(path)
for name, a in model.state_dict().items():
b = model2[name]
a, b = a.numpy(), b.numpy()
assert a.shape == b.shape
assert a.dtype == b.dtype
assert np.array_equal(a, b)
if __name__ == '__main__':
unittest.main()

View File

@@ -5,7 +5,7 @@ import io
import unittest
import numpy as np
import onnx
from extra.utils import fetch
from extra.utils import fetch, temp
from extra.onnx import get_run_onnx
from tinygrad.tensor import Tensor
@@ -60,8 +60,8 @@ class TestOnnxModel(unittest.TestCase):
tinygrad_out = tinygrad_out.numpy()
pr.disable()
stats = pstats.Stats(pr)
stats.dump_stats("/tmp/net.prof")
os.system("flameprof /tmp/net.prof > /tmp/prof.svg")
stats.dump_stats(temp("net.prof"))
os.system(f"flameprof {temp('net.prof')} > {temp('prof.svg')}")
ps = stats.sort_stats(pstats.SortKey.TIME)
ps.print_stats(30)

View File

@@ -1,6 +1,8 @@
#!/usr/bin/env python
import unittest
import numpy as np
from extra.utils import WINDOWS
from tinygrad.helpers import getenv
from tinygrad.jit import TinyJit
from tinygrad.tensor import Tensor, Device
from tinygrad.nn import BatchNorm2d, Conv2d, ConvTranspose2d, Linear, GroupNorm, LayerNorm, LayerNorm2d, Embedding, InstanceNorm
@@ -93,6 +95,7 @@ class TestNN(unittest.TestCase):
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
@unittest.skipIf(getenv("CI", "") != "" and WINDOWS, "runs out of memory in CI")
def test_conv_transpose2d(self):
BS, C1, H, W = 4, 16, 224, 224
C2, K, S, P = 64, 7, 2, 1

View File

@@ -6,7 +6,7 @@ from tinygrad.state import safe_load, safe_save, get_state_dict
from tinygrad.helpers import dtypes
from tinygrad.runtime.ops_disk import RawDiskBuffer
from extra.helpers import Timing
from extra.utils import fetch_as_file
from extra.utils import fetch_as_file, temp
from tinygrad.state import torch_load, get_state_dict
def compare_weights_both(url):
@@ -59,77 +59,77 @@ class TestSafetensors(unittest.TestCase):
"weight3": torch.arange(0, 17, dtype=torch.int32).reshape(17,1,1),
"weight4": torch.arange(0, 2, dtype=torch.uint8),
}
save_file(tensors, "/tmp/model.safetensors")
save_file(tensors, temp("model.safetensors"))
ret = safe_load("/tmp/model.safetensors")
ret = safe_load(temp("model.safetensors"))
for k,v in tensors.items(): np.testing.assert_array_equal(ret[k].numpy(), v.numpy())
safe_save(ret, "/tmp/model.safetensors_alt")
with open("/tmp/model.safetensors", "rb") as f:
with open("/tmp/model.safetensors_alt", "rb") as g:
safe_save(ret, temp("model.safetensors_alt"))
with open(temp("model.safetensors"), "rb") as f:
with open(temp("model.safetensors_alt"), "rb") as g:
assert f.read() == g.read()
ret2 = safe_load("/tmp/model.safetensors_alt")
ret2 = safe_load(temp("model.safetensors_alt"))
for k,v in tensors.items(): np.testing.assert_array_equal(ret2[k].numpy(), v.numpy())
def test_efficientnet_safetensors(self):
from models.efficientnet import EfficientNet
model = EfficientNet(0)
state_dict = get_state_dict(model)
safe_save(state_dict, "/tmp/eff0")
state_dict_loaded = safe_load("/tmp/eff0")
safe_save(state_dict, temp("eff0"))
state_dict_loaded = safe_load(temp("eff0"))
assert sorted(list(state_dict_loaded.keys())) == sorted(list(state_dict.keys()))
for k,v in state_dict.items():
np.testing.assert_array_equal(v.numpy(), state_dict_loaded[k].numpy())
# load with the real safetensors
from safetensors import safe_open
with safe_open("/tmp/eff0", framework="pt", device="cpu") as f:
with safe_open(temp("eff0"), framework="pt", device="cpu") as f:
assert sorted(list(f.keys())) == sorted(list(state_dict.keys()))
for k in f.keys():
np.testing.assert_array_equal(f.get_tensor(k).numpy(), state_dict[k].numpy())
class TestDiskTensor(unittest.TestCase):
def test_empty(self):
pathlib.Path("/tmp/dt1").unlink(missing_ok=True)
Tensor.empty(100, 100, device="disk:/tmp/dt1")
pathlib.Path(temp("dt1")).unlink(missing_ok=True)
Tensor.empty(100, 100, device=f"disk:{temp('dt1')}")
def test_write_ones(self):
pathlib.Path("/tmp/dt2").unlink(missing_ok=True)
pathlib.Path(temp("dt2")).unlink(missing_ok=True)
out = Tensor.ones(10, 10, device="CPU")
outdisk = out.to("disk:/tmp/dt2")
outdisk = out.to(f"disk:{temp('dt2')}")
print(outdisk)
outdisk.realize()
del out, outdisk
# test file
with open("/tmp/dt2", "rb") as f:
with open(temp("dt2"), "rb") as f:
assert f.read() == b"\x00\x00\x80\x3F" * 100
# test load alt
reloaded = Tensor.empty(10, 10, device="disk:/tmp/dt2")
reloaded = Tensor.empty(10, 10, device=f"disk:{temp('dt2')}")
out = reloaded.numpy()
assert np.all(out == 1.)
def test_slice(self):
pathlib.Path("/tmp/dt3").unlink(missing_ok=True)
Tensor.arange(10, device="CPU").to("disk:/tmp/dt3").realize()
pathlib.Path(temp("dt3")).unlink(missing_ok=True)
Tensor.arange(10, device="CPU").to(f"disk:{temp('dt3')}").realize()
slice_me = Tensor.empty(10, device="disk:/tmp/dt3")
slice_me = Tensor.empty(10, device=f"disk:{temp('dt3')}")
print(slice_me)
is_3 = slice_me[3:4].cpu()
assert is_3.numpy()[0] == 3
def test_slice_2d(self):
pathlib.Path("/tmp/dt5").unlink(missing_ok=True)
Tensor.arange(100, device="CPU").to("disk:/tmp/dt5").realize()
slice_me = Tensor.empty(10, 10, device="disk:/tmp/dt5")
pathlib.Path(temp("dt5")).unlink(missing_ok=True)
Tensor.arange(100, device="CPU").to(f"disk:{temp('dt5')}").realize()
slice_me = Tensor.empty(10, 10, device=f"disk:{temp('dt5')}")
tst = slice_me[1].numpy()
print(tst)
np.testing.assert_allclose(tst, np.arange(10, 20))
def test_assign_slice(self):
pathlib.Path("/tmp/dt4").unlink(missing_ok=True)
cc = Tensor.arange(10, device="CPU").to("disk:/tmp/dt4").realize()
pathlib.Path(temp("dt4")).unlink(missing_ok=True)
cc = Tensor.arange(10, device="CPU").to(f"disk:{temp('dt4')}").realize()
#cc.assign(np.ones(10)).realize()
print(cc[3:5].numpy())