mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
* fix: revive torch backend
* as_strided view vs copy
* Revert "as_strided view vs copy"
This reverts commit 82a61223f2.
* add extra tests (move inplace, add fusion tests)
* better fusion with inplace_op
* no optimizer hooks (break mnist training fusion)
* split off fusion tests in separate file, assert on resnet fusion
fix: remove comments
* cleanup, reduce diff
* reduce diff
* better fusion and identity checks
---------
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
29 lines
1.3 KiB
Python
29 lines
1.3 KiB
Python
# Runs a pretrained torchvision ResNet18 on the "tiny" torch device, checks the
# predicted class for a known test image, and asserts that the number of kernels
# recorded in GlobalCounters does not exceed a fixed budget.
from PIL import Image
from tinygrad.helpers import getenv, GlobalCounters
import torch, torchvision, pathlib, warnings
import torchvision.transforms as transforms
# NOTE(review): imported only for its side effects -- nothing from it is referenced
# by name below. Presumably this is what makes the "tiny" device known to torch; confirm.
import extra.torch_backend.backend

device = "tiny"
torch.set_default_device(device)

if __name__ == "__main__":
  # Zero the counters first so kernel_count read below starts from this point.
  GlobalCounters.reset()

  # The test image is addressed relative to this file: three directories up, then test/models/...
  img = Image.open(pathlib.Path(__file__).parent.parent.parent / "test/models/efficientnet/Chicken.jpg").convert('RGB')
  transform = transforms.Compose([
    transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ])
  # unsqueeze(0) adds a leading batch dimension before moving the tensor to the device.
  img = transform(img).unsqueeze(0).to(device)

  model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
  # EVAL defaults to 1; set EVAL=0 in the environment to leave the model in training mode.
  if getenv("EVAL", 1): model.eval()
  out = model(img).detach().cpu().numpy()
  print("output:", out.shape, out.argmax())
  assert out.argmax() == 7  # cock

  kernel_count = GlobalCounters.kernel_count
  assert kernel_count > 0, "No kernels, test failed"
  # Upper bound on kernels for this forward pass; exceeding it fails the run,
  # coming in under it only warns so the bound can be tightened.
  expected_kernels = 228
  expectation = f"ResNet18 kernels are {kernel_count} vs {expected_kernels} expected."
  if kernel_count < expected_kernels: warnings.warn(f"{expectation} Expectation can be lowered.", UserWarning)
  assert kernel_count <= expected_kernels, expectation