diff --git a/tinygrad/nn.py b/tinygrad/nn.py
index feba126221..74ff8a6c44 100644
--- a/tinygrad/nn.py
+++ b/tinygrad/nn.py
@@ -3,7 +3,7 @@ import numpy as np
 
 class BatchNorm2D:
   def __init__(self, sz, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
-    assert affine == True
+    assert affine == True, "BatchNorm2D is only supported with affine"
     self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
 
     self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz)
diff --git a/tinygrad/ops/ops_cpu.py b/tinygrad/ops/ops_cpu.py
index a454f56f1d..0686f8100a 100644
--- a/tinygrad/ops/ops_cpu.py
+++ b/tinygrad/ops/ops_cpu.py
@@ -2,22 +2,15 @@ import numpy as np
 from ..tensor import Function
 
 class CPUBuffer(np.ndarray):
-  def log(x):
-    return np.log(x)
-  def exp(x):
-    return np.exp(x)
-  def relu(x):
-    return np.maximum(x, 0)
-  def expand(x, shp):
-    return np.broadcast_to(x, shp)
+  log = lambda x: np.log(x)
+  exp = lambda x: np.exp(x)
+  relu = lambda x: np.maximum(x, 0)
+  expand = lambda x,shp: np.broadcast_to(x, shp)
+  permute = lambda x,order: x.transpose(order)
+  type = lambda x,tt: x.astype(tt)
+  custompad = lambda x,padding: np.pad(x, padding)
   def amax(x, *args, **kwargs):
     return np.amax(x, *args, **kwargs)
-  def permute(x, order):
-    return x.transpose(order)
-  def type(x, tt):
-    return x.astype(tt)
-  def custompad(x, padding):
-    return np.pad(x, padding)
   def toCPU(x):
     return x
   @staticmethod
diff --git a/tinygrad/ops/ops_torch.py b/tinygrad/ops/ops_torch.py
index c1b04d852a..7eaa661aff 100644
--- a/tinygrad/ops/ops_torch.py
+++ b/tinygrad/ops/ops_torch.py
@@ -5,8 +5,7 @@ from ..tensor import Function
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 class TorchBuffer(torch.Tensor):
-  def custompad(x, padding):
-    return torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
+  custompad = lambda x,padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
   @staticmethod
   def fromCPU(data):
     return TorchBuffer(torch.from_numpy(data).requires_grad_(False)).to(device)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index d31ec84fb3..bfa3af1312 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -1,5 +1,5 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
-import os, atexit, time, inspect, functools
+import os, atexit, time, inspect, functools, importlib
 from collections import defaultdict
 import numpy as np
 
@@ -30,12 +30,13 @@ class ProfileOp:
     return self
   def __exit__(self, *junk):
     if GRAPH:
-      saved_tensors = filter(lambda x: any([isinstance(x, v) for v in Device.buffers.values()]), self.ctx.saved_tensors)
+      # connect inputs to outputs
       for x in self.x:
         for y in self.output:
           G.add_edge(id(x.data), id(y.data), label=self.name, color="blue" if self.backward else "black")
           G.nodes[id(x.data)]['label'], G.nodes[id(y.data)]['label'] = str(x.shape), str(y.shape)
       # which saved tensors does this backward depend on?
+      saved_tensors = filter(lambda x: any([isinstance(x, v) for v in Device.buffers.values()]), self.ctx.saved_tensors)
       if self.backward:
         for x in saved_tensors:
           for y in self.output:
@@ -243,12 +244,12 @@ class Tensor:
   def sum(self, axis=None, keepdim=False):
     axis, out_shape = self._canonicalize_reduce_axis(axis)
     ret = self._sum(axis=axis)
-    return ret if keepdim else ret.reshape(shape=out_shape)
+    return ret if keepdim or ret.shape == out_shape else ret.reshape(shape=out_shape)
 
   def max(self, axis=None, keepdim=False):
     axis, out_shape = self._canonicalize_reduce_axis(axis)
     ret = self._max(axis=axis)
-    return ret if keepdim else ret.reshape(shape=out_shape)
+    return ret if keepdim or ret.shape == out_shape else ret.reshape(shape=out_shape)
 
   def mean(self, axis=None, keepdim=False):
     out = self.sum(axis=axis, keepdim=keepdim)
@@ -410,7 +411,6 @@ def _register_ops(namespace, device=Device.CPU):
     if name.endswith("Buffer"): Device.buffers[device] = cls
     elif name[0] != "_": register(name.lower(), cls, device=device)
 
-import importlib
 for d,ops in Device.imports.items():
   try:
     _register_ops(importlib.import_module('tinygrad.ops.'+ops), d)
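For context (not part of the patch): a minimal standalone sketch of the pattern the CPUBuffer/TorchBuffer refactor relies on, namely that a lambda assigned in a class body behaves like an ordinary method, receiving the instance as its first argument on attribute access. The class name Buf and the sample array here are illustrative only.

import numpy as np

class Buf(np.ndarray):
  # Lambdas stored as class attributes are plain functions, so instance
  # attribute access binds the instance as the first argument (x),
  # just like a def-defined method.
  relu = lambda x: np.maximum(x, 0)
  expand = lambda x, shp: np.broadcast_to(x, shp)

a = np.array([-1.0, 2.0]).view(Buf)
print(a.relu())          # elementwise max with 0 -> [0. 2.]
print(a.expand((2, 2)))  # the 1-D array broadcast to shape (2, 2)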