diff --git a/models/convnext.py b/models/convnext.py
new file mode 100644
index 0000000000..f79071133d
--- /dev/null
+++ b/models/convnext.py
@@ -0,0 +1,64 @@
+from tinygrad.tensor import Tensor
+from tinygrad.nn import Conv2d, LayerNorm, Linear
+
+class Block:
+  def __init__(self, dim):
+    self.dwconv = Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
+    self.norm = LayerNorm(dim, eps=1e-6)
+    self.pwconv1 = Linear(dim, 4 * dim)
+    self.pwconv2 = Linear(4 * dim, dim)
+    self.gamma = Tensor.ones(dim)
+
+  def __call__(self, x:Tensor):
+    return x + x.sequential([
+      self.dwconv, lambda x: x.permute(0, 2, 3, 1), self.norm,
+      self.pwconv1, Tensor.gelu, self.pwconv2, lambda x: (self.gamma * x).permute(0, 3, 1, 2)
+    ])
+
+class ConvNeXt:
+  def __init__(self, in_chans=3, num_classes=1000, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]):
+    self.downsample_layers = [
+      [Conv2d(in_chans, dims[0], kernel_size=4, stride=4), LayerNorm((dims[0], 1, 1), eps=1e-6)],
+      *[[LayerNorm((dims[i], 1, 1), eps=1e-6), Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2)] for i in range(len(dims)-1)]
+    ]
+    self.stages = [[Block(dims[i]) for _ in range(depths[i])] for i in range(len(dims))]
+    self.norm = LayerNorm(dims[-1])
+    self.head = Linear(dims[-1], num_classes)
+
+  def __call__(self, x:Tensor):
+    for downsample, stage in zip(self.downsample_layers, self.stages):
+      x = x.sequential(downsample).sequential(stage)
+    return x.mean([-2, -1]).sequential([self.norm, self.head])
+
+# *** model definition is done ***
+
+versions = {
+  "tiny": {"depths": [3, 3, 9, 3], "dims": [96, 192, 384, 768]},
+  "small": {"depths": [3, 3, 27, 3], "dims": [96, 192, 384, 768]},
+  "base": {"depths": [3, 3, 9, 3], "dims": [128, 256, 512, 1024]},
+  "large": {"depths": [3, 3, 27, 3], "dims": [192, 384, 768, 1536]},
+  "xlarge": {"depths": [3, 3, 27, 3], "dims": [256, 512, 1024, 2048]}
+}
+
+def get_model(version, load_weights=False):
+  model = ConvNeXt(**versions[version])
+  if load_weights:
+    from extra.utils import fetch, fake_torch_load, get_child
+    weights = fake_torch_load(fetch(f'https://dl.fbaipublicfiles.com/convnext/convnext_{version}_1k_224_ema.pth'))['model']
+    for k,v in weights.items():
+      mv = get_child(model, k)
+      mv.assign(v.reshape(mv.shape)).realize()
+  return model
+
+if __name__ == "__main__":
+  model = get_model("tiny", True)
+
+  # load image
+  from test.models.test_efficientnet import chicken_img, preprocess, _LABELS
+  img = Tensor(preprocess(chicken_img))
+
+  Tensor.training = False
+  Tensor.no_grad = True
+
+  out = model(img).numpy()
+  print(_LABELS[out.argmax()])
diff --git a/test/models/test_efficientnet.py b/test/models/test_efficientnet.py
index 709cec9ae6..d7f0ea1821 100644
--- a/test/models/test_efficientnet.py
+++ b/test/models/test_efficientnet.py
@@ -6,6 +6,7 @@ import unittest
 import numpy as np
 from PIL import Image
 
+from tinygrad.helpers import getenv
 from models.efficientnet import EfficientNet
 from models.vit import ViT
 from tinygrad.tensor import Tensor
@@ -54,7 +55,7 @@ car_img = Image.open(pathlib.Path(__file__).parent / 'efficientnet/car.jpg')
 class TestEfficientNet(unittest.TestCase):
   @classmethod
   def setUpClass(cls):
-    cls.model = EfficientNet(number=0)
+    cls.model = EfficientNet(number=getenv("NUM"))
     cls.model.load_from_pretrained()
 
   @classmethod
diff --git a/test/models/test_train.py b/test/models/test_train.py
index 00df14aa52..e0c2778dd8 100644
--- a/test/models/test_train.py
+++ b/test/models/test_train.py
@@ -5,6 +5,7 @@ from tinygrad.nn import optim
 from tinygrad.tensor import Device
 from tinygrad.helpers import getenv
 from extra.training import train
+from models.convnext import ConvNeXt
 from models.efficientnet import EfficientNet
 from models.transformer import Transformer
 from models.vit import ViT
@@ -24,18 +25,32 @@ def train_one_step(model,X,Y):
   et = time.time()-st
   print("done in %.2f ms" % (et*1000.))
 
+def check_gc():
+  if Device.DEFAULT == "GPU":
+    from extra.introspection import print_objects
+    assert print_objects() == 0
+
 class TestTrain(unittest.TestCase):
+  def test_convnext(self):
+    model = ConvNeXt(depths=[1], dims=[16])
+    X = np.zeros((BS,3,224,224), dtype=np.float32)
+    Y = np.zeros((BS), dtype=np.int32)
+    train_one_step(model,X,Y)
+    check_gc()
+
   def test_efficientnet(self):
     model = EfficientNet(0)
     X = np.zeros((BS,3,224,224), dtype=np.float32)
     Y = np.zeros((BS), dtype=np.int32)
     train_one_step(model,X,Y)
+    check_gc()
 
   def test_vit(self):
     model = ViT()
     X = np.zeros((BS,3,224,224), dtype=np.float32)
     Y = np.zeros((BS,), dtype=np.int32)
     train_one_step(model,X,Y)
+    check_gc()
 
   def test_transformer(self):
     # this should be small GPT-2, but the param count is wrong
@@ -44,10 +59,7 @@ class TestTrain(unittest.TestCase):
     X = np.zeros((BS,6), dtype=np.float32)
     Y = np.zeros((BS,6), dtype=np.int32)
     train_one_step(model,X,Y)
-
-    if Device.DEFAULT == "GPU":
-      from extra.introspection import print_objects
-      assert print_objects() == 0
+    check_gc()
 
   def test_resnet(self):
     X = np.zeros((BS, 3, 224, 224), dtype=np.float32)
@@ -56,6 +68,7 @@ class TestTrain(unittest.TestCase):
       model = resnet_v()
       model.load_from_pretrained()
       train_one_step(model, X, Y)
+      check_gc()
 
   def test_bert(self):
     # TODO: write this
diff --git a/test/test_tensor.py b/test/test_tensor.py
index 7f9e22fe1a..013d67122c 100644
--- a/test/test_tensor.py
+++ b/test/test_tensor.py
@@ -149,9 +149,9 @@ class TestTinygrad(unittest.TestCase):
     for random_fn in [Tensor.randn, Tensor.uniform, Tensor.scaled_uniform, Tensor.glorot_uniform]:
       with self.subTest(msg=f"Tensor.{random_fn.__name__}"):
         Tensor.manual_seed(1337)
-        a = random_fn(10,10)
+        a = random_fn(10,10).realize()
         Tensor.manual_seed(1337)
-        b = random_fn(10,10)
+        b = random_fn(10,10).realize()
         np.testing.assert_allclose(a.numpy(), b.numpy())
 
 if __name__ == '__main__':
diff --git a/tinygrad/codegen/gpu.py b/tinygrad/codegen/gpu.py
index 38b33a2412..41bd88855c 100644
--- a/tinygrad/codegen/gpu.py
+++ b/tinygrad/codegen/gpu.py
@@ -325,13 +325,13 @@ class GPUCodegen(ASTKernel):
       [") {\n"] + self.kernel)
 
     # kernel function definition
-    function_name = ("re_S" if self.reduceop else "ew_S") + '_'.join([str(x) for x in self.full_shape])
+    function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) for x in self.full_shape])
 
     # painfully name the function
     if prg in GPUCodegen.kernel_name_cache: function_name = GPUCodegen.kernel_name_cache[prg]
     else:
       GPUCodegen.kernel_cnt[function_name] += 1
-      if GPUCodegen.kernel_cnt[function_name]: function_name = f"{function_name}{'_N'+str(GPUCodegen.kernel_cnt[function_name])}"
+      if GPUCodegen.kernel_cnt[function_name]: function_name = f"{function_name}{'n'+str(GPUCodegen.kernel_cnt[function_name])}"
       GPUCodegen.kernel_name_cache[prg] = function_name
 
     return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name), self.bufs_to_delete,
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 7b397aa1c4..ae96f3f3c9 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -78,6 +78,13 @@ def replace_with_movement_op(y:Union[LazyOp, LazyBuffer], op:MovementOps, arg:Tu
   assert y.op in BinaryOps or y.op in UnaryOps
   return elementwise_op(y.op, *[replace_with_movement_op(z, op, arg) for z in y.src])  # type: ignore
 
+class LazyNumpyArray:
+  def __init__(self, fxn, shape): self.fxn, self.shape = fxn, shape
+  def __call__(self): return self.fxn(self.shape)
+  def reshape(self, new_shape): return LazyNumpyArray(self.fxn, new_shape)
+  def copy(self): return self
+  def astype(self, typ): return self
+
 def support_weakref(x): return x
 @support_weakref # needed for mypyc, this prevents LazyBuffer from becoming a native class
 class LazyBuffer:
@@ -115,7 +122,7 @@ class LazyBuffer:
     if self.realized is None:
       # get real ops first
       if self.op.op == LoadOps.FROMCPU:
-        self.realized = Device._buffers[self.device].fromCPU(self.op.arg)
+        self.realized = Device._buffers[self.device].fromCPU(self.op.arg() if isinstance(self.op.arg, LazyNumpyArray) else self.op.arg)
         ast = LazyOp(self.op.op, tuple())
       elif self.op.op == LoadOps.CONTIGUOUS:
         real_src = self.op.src[0].realize(self.device)
@@ -149,6 +156,7 @@ class LazyBuffer:
     assert isinstance(self.realized, Device._buffers[self.device])
     return self.realized
 
+  # NOTE: we have to make a copy of the numpy array here in case the user changes it. expose this?
   @staticmethod
   def fromCPU(x, device) -> LazyBuffer: return LazyBuffer(device, x.shape, LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), x.copy()))
   def toCPU(self): return self.realize().toCPU()
diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py
index 31b9193040..05e019c51b 100644
--- a/tinygrad/nn/__init__.py
+++ b/tinygrad/nn/__init__.py
@@ -39,21 +39,19 @@ class BatchNorm2d:
 
 # TODO: is this good weight init?
 class Conv2d:
-  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
+  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
     self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else (kernel_size[0], kernel_size[1])
-    self.stride = (stride, stride) if isinstance(stride, int) else (stride[0], stride[1])
-    self.padding = (padding, ) * 4 if isinstance(padding, int) else ((padding[0], padding[0], padding[1], padding[1]) if len(padding) == 2 else padding)
-    # TODO: why is this realize needed? shouldn't it realize on the first run?
-    self.weight = Tensor.glorot_uniform(out_channels, in_channels, self.kernel_size[0], self.kernel_size[1]).realize()
-    self.bias = Tensor.zeros(out_channels).contiguous().realize() if bias else None
+    self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
+    self.weight = Tensor.glorot_uniform(out_channels, in_channels//groups, self.kernel_size[0], self.kernel_size[1])
+    self.bias = Tensor.zeros(out_channels) if bias else None
 
   def __call__(self, x):
-    return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride)
+    return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
 
 class Linear:
   def __init__(self, in_features, out_features, bias=True):
-    self.weight = Tensor.glorot_uniform(out_features, in_features).realize()
-    self.bias = Tensor.zeros(out_features).contiguous().realize() if bias else None
+    self.weight = Tensor.glorot_uniform(out_features, in_features)
+    self.bias = Tensor.zeros(out_features) if bias else None
 
   def __call__(self, x):
     return x.linear(self.weight.transpose(), self.bias)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 6c0eb9bb07..fd30e238dc 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -152,6 +152,7 @@ class CompiledBuffer(DeviceBuffer): # pylint: disable=abstract-method
     return cls.raw_buffer_type(4*prod(shape)) if backing is None else cls.raw_buffer_type.fromCPU(backing)
   def raw(self) -> RawBuffer:
     if self._buf is None:
+      if DEBUG >= 4 and self._backing is not None: print(f"**** copy in {self._backing.shape} to {type(self)}")
       self._buf = self.create_raw_buffer(self._base_shape, self._backing)
       self._backing = None
     return self._buf
@@ -160,6 +161,7 @@ class CompiledBuffer(DeviceBuffer): # pylint: disable=abstract-method
   def fromCPU(cls, x:np.ndarray) -> CompiledBuffer: return cls(x.shape, backing=x.view(np.ndarray).astype(np.float32).ravel())
   def toCPU(self) -> np.ndarray:
     assert GlobalCounters.cache is None, f"can't copy out {self} while caching"
+    if DEBUG >= 3: print(f"**** copy out {self.shape}")
     return self.contiguous().raw().toCPU().reshape(self.shape)
 
   codegen_type : Any
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index b7bd2146be..9d3db5c942 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -4,7 +4,7 @@ import math, functools, itertools
 import numpy as np
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
 from tinygrad.helpers import prod, argfix, make_pair, getenv, DEBUG, flatten
-from tinygrad.lazy import Device, LazyBuffer
+from tinygrad.lazy import Device, LazyBuffer, LazyNumpyArray
 from tinygrad.image import image_conv2d_decorator
 
 # An instantiation of the Function is the Context
@@ -40,7 +40,7 @@ class Tensor:
       # TODO: this has to realize, it shouldn't have to
       data = data.realize().toCPU()
 
-    if isinstance(data, np.ndarray):
+    if isinstance(data, (np.ndarray, LazyNumpyArray)):
       data = data if data.shape else data.reshape((1,))
       self.lazydata = LazyBuffer.fromCPU(data.astype(np.float32), device)
     elif isinstance(data, LazyBuffer):
@@ -79,7 +79,7 @@ class Tensor:
   def assign(self, x) -> Tensor:
     if not isinstance(x, Tensor): x = Tensor(x)
-    assert self.shape == x.shape
+    assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
     assert not x.requires_grad  # self requires_grad is okay?
     if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
     if self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"):
       x.lazydata.output_buffer = self.lazydata.realized
@@ -132,11 +132,11 @@ class Tensor:
   def manual_seed(seed=None): Tensor._rng = np.random.default_rng(seed=seed)
 
   @staticmethod
-  def rand(*shape, **kwargs) -> Tensor: return Tensor(Tensor._rng.random(size=shape, dtype=np.float32), **kwargs)
+  def rand(*shape, **kwargs) -> Tensor: return Tensor(LazyNumpyArray(lambda shape: Tensor._rng.random(size=shape, dtype=np.float32), shape), **kwargs)
 
   # TODO: replace with a transformation from uniform -> gaussian
   @staticmethod
-  def randn(*shape, **kwargs) -> Tensor: return Tensor(Tensor._rng.standard_normal(size=shape, dtype=np.float32), **kwargs)
+  def randn(*shape, **kwargs) -> Tensor: return Tensor(LazyNumpyArray(lambda shape: Tensor._rng.standard_normal(size=shape, dtype=np.float32), shape), **kwargs)
 
   # ***** rng hlops *****
 
@@ -325,7 +325,7 @@ class Tensor:
     padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
 
     # conv2d is a pooling op (with padding)
-    x = self.pad2d(padding_)._pool((H,W),stride, dilation)
+    x = self.pad2d(padding_)._pool((H,W), stride, dilation)
     oy, ox, rcout = x.shape[2], x.shape[3], cout//groups
 
     # NOTE: we do this expand explicitly so the permute isn't pushed in the binop
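Below is a short usage sketch, not part of the patch, showing how the new groups/dilation arguments on nn.Conv2d and the ConvNeXt model added above fit together. It assumes the tinygrad/models layout this diff targets and reuses the single-stage config from test_train.py.

# sketch only: exercises the new Conv2d(groups=...) path and the ConvNeXt model
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d
from models.convnext import ConvNeXt

# depthwise 7x7 conv, as used in Block: groups equals the channel count
dw = Conv2d(16, 16, kernel_size=7, padding=3, groups=16)
print(dw(Tensor.randn(1, 16, 32, 32)).shape)      # (1, 16, 32, 32)

# single-stage config from test_train.py; head maps 16 dims to 1000 classes
model = ConvNeXt(depths=[1], dims=[16])
print(model(Tensor.randn(1, 3, 224, 224)).shape)  # (1, 1000)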