diff --git a/docs/nn.md b/docs/nn.md
index 7be0159fc8..f4fcc5164e 100644
--- a/docs/nn.md
+++ b/docs/nn.md
@@ -3,6 +3,7 @@
 ::: tinygrad.nn.BatchNorm2d
 ::: tinygrad.nn.Conv1d
 ::: tinygrad.nn.Conv2d
+::: tinygrad.nn.ConvTranspose1d
 ::: tinygrad.nn.ConvTranspose2d
 ::: tinygrad.nn.Linear
 ::: tinygrad.nn.GroupNorm
@@ -26,4 +27,4 @@
 ::: tinygrad.nn.state.get_state_dict
 ::: tinygrad.nn.state.get_parameters
 ::: tinygrad.nn.state.load_state_dict
-::: tinygrad.nn.state.torch_load
\ No newline at end of file
+::: tinygrad.nn.state.torch_load
diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py
index c04c1644f1..42318ca7b4 100644
--- a/tinygrad/nn/__init__.py
+++ b/tinygrad/nn/__init__.py
@@ -5,6 +5,24 @@ from tinygrad.helpers import prod
 from tinygrad.nn import optim, state  # noqa: F401
 
 class BatchNorm2d:
+  """
+  Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension).
+
+  - Described: https://paperswithcode.com/method/batch-normalization
+  - Paper: https://arxiv.org/abs/1502.03167v3
+
+  See: `Tensor.batchnorm`
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  norm = nn.BatchNorm2d(3)
+  t = Tensor.rand(2, 3, 4, 4)
+  print(t.mean().item(), t.std().item())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = norm(t)
+  print(t.mean().item(), t.std().item())
+  ```
+  """
   def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
     self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
 
@@ -38,9 +56,38 @@ class BatchNorm2d:
 
 # TODO: these Conv lines are terrible
 def Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+  """
+  Applies a 1D convolution over an input signal composed of several input planes.
+
+  See: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  conv = nn.Conv1d(1, 1, 3)
+  t = Tensor.rand(1, 1, 4)
+  print(t.numpy())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = conv(t)
+  print(t.numpy())
+  ```
+  """
   return Conv2d(in_channels, out_channels, (kernel_size,), stride, padding, dilation, groups, bias)
 
 class Conv2d:
+  """
+  Applies a 2D convolution over an input signal composed of several input planes.
+
+  See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  conv = nn.Conv2d(1, 1, 3)
+  t = Tensor.rand(1, 1, 4, 4)
+  print(t.numpy())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = conv(t)
+  print(t.numpy())
+  ```
+  """
   def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
     self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
     self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
@@ -55,9 +102,39 @@ class Conv2d:
     return Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))
 
 def ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
+  """
+  Applies a 1D transposed convolution operator over an input signal composed of several input planes.
+
+  See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  conv = nn.ConvTranspose1d(1, 1, 3)
+  t = Tensor.rand(1, 1, 4)
+  print(t.numpy())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = conv(t)
+  print(t.numpy())
+  ```
+  """
   return ConvTranspose2d(in_channels, out_channels, (kernel_size,), stride, padding, output_padding, dilation, groups, bias)
 
 class ConvTranspose2d(Conv2d):
+  """
+  Applies a 2D transposed convolution operator over an input image.
+
+  See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  conv = nn.ConvTranspose2d(1, 1, 3)
+  t = Tensor.rand(1, 1, 4, 4)
+  print(t.numpy())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = conv(t)
+  print(t.numpy())
+  ```
+  """
   def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
     super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
     self.output_padding = output_padding
@@ -70,6 +147,21 @@ class ConvTranspose2d(Conv2d):
     return Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))
 
 class Linear:
+  """
+  Applies a linear transformation to the incoming data.
+
+  See: https://pytorch.org/docs/stable/generated/torch.nn.Linear
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  lin = nn.Linear(3, 4)
+  t = Tensor.rand(2, 3)
+  print(t.numpy())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = lin(t)
+  print(t.numpy())
+  ```
+  """
   def __init__(self, in_features, out_features, bias=True):
     # TODO: is this init good? torch inits to uniform(-1/sqrt(in_features), 1/sqrt(in_features))
     self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5))
@@ -80,6 +172,22 @@ class Linear:
     return x.linear(self.weight.transpose(), self.bias)
 
 class GroupNorm:
+  """
+  Applies Group Normalization over a mini-batch of inputs.
+
+  - Described: https://paperswithcode.com/method/group-normalization
+  - Paper: https://arxiv.org/abs/1803.08494v3
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  norm = nn.GroupNorm(2, 12)
+  t = Tensor.rand(2, 12, 4, 4) * 2 + 1
+  print(t.mean().item(), t.std().item())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = norm(t)
+  print(t.mean().item(), t.std().item())
+  ```
+  """
   def __init__(self, num_groups:int, num_channels:int, eps:float=1e-5, affine:bool=True):
     self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
     self.weight: Optional[Tensor] = Tensor.ones(num_channels) if affine else None
@@ -95,6 +203,22 @@ class GroupNorm:
     return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))
 
 class InstanceNorm:
+  """
+  Applies Instance Normalization over a mini-batch of inputs.
+
+  - Described: https://paperswithcode.com/method/instance-normalization
+  - Paper: https://arxiv.org/abs/1607.08022v3
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  norm = nn.InstanceNorm(3)
+  t = Tensor.rand(2, 3, 4, 4) * 2 + 1
+  print(t.mean().item(), t.std().item())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = norm(t)
+  print(t.mean().item(), t.std().item())
+  ```
+  """
   def __init__(self, num_features:int, eps:float=1e-5, affine:bool=True):
     self.num_features, self.eps = num_features, eps
     self.weight: Optional[Tensor] = Tensor.ones(num_features) if affine else None
@@ -106,6 +230,22 @@ class InstanceNorm:
     return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))
 
 class LayerNorm:
+  """
+  Applies Layer Normalization over a mini-batch of inputs.
+
+  - Described: https://paperswithcode.com/method/layer-normalization
+  - Paper: https://arxiv.org/abs/1607.06450v1
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  norm = nn.LayerNorm(3)
+  t = Tensor.rand(2, 5, 3) * 2 + 1
+  print(t.mean().item(), t.std().item())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = norm(t)
+  print(t.mean().item(), t.std().item())
+  ```
+  """
   def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
     self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
     self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
@@ -118,9 +258,34 @@ class LayerNorm:
     return x * self.weight + self.bias
 
 class LayerNorm2d(LayerNorm):
+  """
+  Applies Layer Normalization over a mini-batch of 2D inputs.
+
+  See: `LayerNorm`
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  norm = nn.LayerNorm2d(3)
+  t = Tensor.rand(2, 3, 4, 4) * 2 + 1
+  print(t.mean().item(), t.std().item())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  t = norm(t)
+  print(t.mean().item(), t.std().item())
+  ```
+  """
   def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
 
 class Embedding:
+  """
+  A simple lookup table that stores embeddings of a fixed dictionary and size.
+
+  See: https://pytorch.org/docs/stable/generated/torch.nn.Embedding
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  emb = nn.Embedding(10, 3)
+  print(emb(Tensor([1, 2, 3, 1])).numpy())
+  ```
+  """
   def __init__(self, vocab_size:int, embed_size:int):
     self.vocab_sz, self.embed_sz, self.weight = vocab_size, embed_size, Tensor.glorot_uniform(vocab_size, embed_size)
 
diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py
index c428dc985b..fd9c612bcf 100644
--- a/tinygrad/nn/optim.py
+++ b/tinygrad/nn/optim.py
@@ -5,6 +5,9 @@ from tinygrad.tensor import Tensor
 from tinygrad.dtype import dtypes, least_upper_dtype
 
 class Optimizer:
+  """
+  Base class for all optimizers.
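+
+  A rough sketch of the usual update loop; `model`, `loss_fn`, `x`, and `y` below are placeholders for your own code, not tinygrad APIs:
+
+  ```python
+  opt = nn.optim.SGD(nn.state.get_parameters(model), lr=0.001)
+  Tensor.training = True       # optimizers require training mode to be enabled
+  loss = loss_fn(model(x), y)  # forward pass on a batch
+  opt.zero_grad()              # clear old gradients
+  loss.backward()              # populate .grad on the parameters
+  opt.step()                   # apply the update
+  ```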
+  """
   def __init__(self, params: List[Tensor], lr: float):
     # if it's None, but being put into an optimizer, set it to True
     for x in params:
@@ -19,10 +22,20 @@ class Optimizer:
       dtype=least_upper_dtype(dtypes.default_float, dtypes.float32))
 
   def zero_grad(self):
+    """
+    Zeroes the gradients of all the parameters.
+    """
     for param in self.params: param.grad = None
 
-  def step(self): Tensor.realize(*self.schedule_step())
+  def step(self):
+    """
+    Performs a single optimization step.
+    """
+    Tensor.realize(*self.schedule_step())
   def schedule_step(self) -> List[Tensor]:
+    """
+    Returns the tensors that need to be realized to perform a single optimization step.
+    """
     assert Tensor.training, (
             f"""Tensor.training={Tensor.training}, Tensor.training must be enabled to use the optimizer.
                 - help: Consider setting Tensor.training=True before calling Optimizer.step().""")
@@ -30,6 +43,9 @@ class Optimizer:
   def _step(self) -> List[Tensor]: raise NotImplementedError
 
 class OptimizerGroup(Optimizer):
+  """
+  Combines multiple optimizers into one.
+  """
   def __init__(self, *optimizers: Optimizer):  # pylint: disable=super-init-not-called
     self.optimizers = optimizers
     self.params, self.buffers = flatten([o.params for o in self.optimizers]), flatten([o.buffers for o in self.optimizers])
@@ -39,9 +55,22 @@ class OptimizerGroup(Optimizer):
 
 # LARS is essentially just trust ratio to SGD so if we just set the trust coeff 0.0 its just standard SGD.
 def SGD(params: List[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False):
+  """
+  Stochastic Gradient Descent (SGD) optimizer with optional momentum and weight decay.
+
+  `classic` is a boolean flag that determines whether to use the popular momentum update rule or the classic momentum update rule.
+
+  - Described: https://paperswithcode.com/method/sgd
+  """
   return LARS(params, lr, momentum, weight_decay, nesterov, classic, tcoef=0.0)
 
 class LARS(Optimizer):
+  """
+  Layer-wise Adaptive Rate Scaling (LARS) optimizer with optional momentum and weight decay.
+
+  - Described: https://paperswithcode.com/method/lars
+  - Paper: https://arxiv.org/abs/1708.03888v3
+  """
   def __init__(self, params:List[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4, nesterov=False, classic=True, tcoef=0.001):
     super().__init__(params, lr)
     self.momentum, self.wd, self.nesterov, self.classic, self.tcoef = momentum, weight_decay, nesterov, classic, tcoef
@@ -70,13 +99,33 @@ class LARS(Optimizer):
     return self.b
 
 # LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0 its just Adam/W.
-def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, wd=0.01): return LAMB(params, lr, b1, b2, eps, wd, adam=True)
-def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8): return LAMB(params, lr, b1, b2, eps, 0.0, adam=True)
+def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01):
+  """
+  AdamW optimizer with optional weight decay.
+
+  - Described: https://paperswithcode.com/method/adamw
+  - Paper: https://arxiv.org/abs/1711.05101v3
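+
+  For example (a sketch; `model` stands in for whatever model holds the parameters):
+
+  ```python
+  opt = nn.optim.AdamW(nn.state.get_parameters(model), lr=1e-4, weight_decay=1e-2)
+  ```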
+  """
+  return LAMB(params, lr, b1, b2, eps, weight_decay, adam=True)
+def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
+  """
+  Adam optimizer.
+
+  - Described: https://paperswithcode.com/method/adam
+  - Paper: https://arxiv.org/abs/1412.6980
+  """
+  return LAMB(params, lr, b1, b2, eps, 0.0, adam=True)
 
 class LAMB(Optimizer):
-  def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, wd=0.0, adam=False):
+  """
+  LAMB optimizer with optional weight decay.
+
+  - Described: https://paperswithcode.com/method/lamb
+  - Paper: https://arxiv.org/abs/1904.00962
+  """
+  def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False):
     super().__init__(params, lr)
-    self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, wd, adam
+    self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, weight_decay, adam
     self.b1_t, self.b2_t = (Tensor([1], dtype=dtypes.float32, device=self.device, requires_grad=False).realize() for _ in [b1, b2])
     self.m = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
     self.v = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py
index 57148768a6..413fb15c4b 100644
--- a/tinygrad/nn/state.py
+++ b/tinygrad/nn/state.py
@@ -12,11 +12,21 @@ safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dt
 inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
 
 def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
+  """
+  Loads a .safetensor file from disk, returning the data, metadata length, and metadata.
+  """
   t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
   json_len = t[0:8].bitcast(dtypes.int64).item()
   return t, json_len, json.loads(t[8:8+json_len].numpy().tobytes())
 
 def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
+  """
+  Loads a .safetensor file from disk, returning the state_dict.
+
+  ```python
+  state_dict = nn.state.safe_load("test.safetensor")
+  ```
+  """
   t, json_len, metadata = safe_load_metadata(fn)
   ret = {}
   for k,v in metadata.items():
@@ -27,6 +37,14 @@ def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
   return ret
 
 def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
+  """
+  Saves a state_dict to disk in a .safetensor file with optional metadata.
+
+  ```python
+  t = Tensor([1, 2, 3])
+  nn.state.safe_save({'t':t}, "test.safetensor")
+  ```
+  """
   headers, offset = {}, 0
   if metadata: headers['__metadata__'] = metadata
   for k,v in tensors.items():
@@ -44,6 +62,19 @@ def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any
 
 from collections import OrderedDict
 def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
+  """
+  Returns a state_dict of the object, with optional prefix.
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  class Net:
+    def __init__(self):
+      self.l1 = nn.Linear(4, 5)
+      self.l2 = nn.Linear(5, 6)
+
+  net = Net()
+  print(nn.state.get_state_dict(net).keys())
+  ```
+  """
   if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
   if hasattr(obj, '_asdict'): return get_state_dict(obj._asdict(), prefix, tensor_type)  # namedtuple
   if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
@@ -54,9 +85,35 @@ def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
   elif isinstance(obj, dict):
     for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
   return state_dict
-def get_parameters(obj) -> List[Tensor]: return list(get_state_dict(obj).values())
+def get_parameters(obj) -> List[Tensor]:
+  """
+  ```python exec="true" source="above" session="tensor" result="python"
+  class Net:
+    def __init__(self):
+      self.l1 = nn.Linear(4, 5)
+      self.l2 = nn.Linear(5, 6)
+
+  net = Net()
+  print(len(nn.state.get_parameters(net)))
+  ```
+  """
+  return list(get_state_dict(obj).values())
 
 def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=True, consume=False) -> None:
+  """
+  Loads a state_dict into a model.
+
+  ```python
+  class Net:
+    def __init__(self):
+      self.l1 = nn.Linear(4, 5)
+      self.l2 = nn.Linear(5, 6)
+
+  net = Net()
+  state_dict = nn.state.get_state_dict(net)
+  nn.state.load_state_dict(net, state_dict)
+  ```
+  """
   start_mem_used = GlobalCounters.mem_used
   with Timing("loaded weights in ", lambda et_ns: f", {(GlobalCounters.mem_used-start_mem_used)/1e9:.2f} GB loaded at {(GlobalCounters.mem_used-start_mem_used)/et_ns:.2f} GB/s"):  # noqa: E501
     model_state_dict = get_state_dict(model)
@@ -76,6 +133,13 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr
 
 # torch support!
 def torch_load(fn:str) -> Dict[str, Tensor]:
+  """
+  Loads a torch .pth file from disk.
+
+  ```python
+  state_dict = nn.state.torch_load("test.pth")
+  ```
+  """
   t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
 
   offsets: Dict[Union[str, int], int] = {}
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 442c08fd83..c410298880 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -74,7 +74,7 @@ class Tensor:
   A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
 
   ```python exec="true" session="tensor"
-  from tinygrad import Tensor, dtypes
+  from tinygrad import Tensor, dtypes, nn
   import numpy as np
   import math
   np.set_printoptions(precision=4)