diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 41054519cc..a482e36afe 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -112,7 +112,7 @@ jobs: - name: Run Pytest run: TORCH=1 python -m pytest -s -v -n=auto test/ - name: Run ONNX - run: TORCH=1 python -m pytest test/external/external_test_onnx_backend.py || true + run: TORCH=1 python -m pytest test/external/external_test_onnx_backend.py --tb=no --disable-warnings || true testgpu: name: GPU Tests diff --git a/datasets/imagenet_download.py b/datasets/imagenet_download.py new file mode 100644 index 0000000000..71d9e55cc6 --- /dev/null +++ b/datasets/imagenet_download.py @@ -0,0 +1,51 @@ +# Python version of https://gist.github.com/antoinebrl/7d00d5cb6c95ef194c737392ef7e476a +from extra.utils import download_file +from pathlib import Path +from tqdm import tqdm +import tarfile, os + +def imagenet_extract(file, path, small=False): + with tarfile.open(name=file) as tar: + if small: # Show progressbar only for big files + for member in tar.getmembers(): tar.extract(path=path, member=member) + else: + for member in tqdm(iterable=tar.getmembers(), total=len(tar.getmembers())): tar.extract(path=path, member=member) + tar.close() + +def imagenet_prepare_val(): + # Read in the labels file + with open(Path(__file__).parent.parent / "datasets/imagenet/imagenet_2012_validation_synset_labels.txt", 'r') as f: + labels = f.read().splitlines() + f.close() + # Get a list of images + images = os.listdir(Path(__file__).parent.parent / "datasets/imagenet/val") + images.sort() + # Create folders and move files into those + for co,dir in enumerate(labels): + os.makedirs(Path(__file__).parent.parent / "datasets/imagenet/val" / dir, exist_ok=True) + os.replace(Path(__file__).parent.parent / "datasets/imagenet/val" / images[co], Path(__file__).parent.parent / "datasets/imagenet/val" / dir / images[co], exist_ok=True) + os.remove(Path(__file__).parent.parent / "datasets/imagenet/imagenet_2012_validation_synset_labels.txt") + +def imagenet_prepare_train(): + images = os.listdir(Path(__file__).parent.parent / "datasets/imagenet/train") + for co,tarf in enumerate(images): + # for each tar file found. Create a folder with its name. Extract into that folder. Remove tar file + if Path(Path(__file__).parent.parent / "datasets/imagenet/train" / images[co]).is_file(): + images[co] = tarf[:-4] # remove .tar from extracted tar files + os.makedirs(Path(__file__).parent.parent / "datasets/imagenet/train" / images[co], exist_ok=True) + imagenet_extract(Path(__file__).parent.parent / "datasets/imagenet/train" / tarf, Path(__file__).parent.parent / "datasets/imagenet/train" / images[co], small=True) + os.remove(Path(__file__).parent.parent / "datasets/imagenet/train" / tarf) + +if __name__ == "__main__": + os.makedirs(Path(__file__).parent.parent / "datasets/imagenet", exist_ok=True) + os.makedirs(Path(__file__).parent.parent / "datasets/imagenet/val", exist_ok=True) + os.makedirs(Path(__file__).parent.parent / "datasets/imagenet/train", exist_ok=True) + download_file("https://raw.githubusercontent.com/raghakot/keras-vis/master/resources/imagenet_class_index.json", Path(__file__).parent.parent / "datasets/imagenet/imagenet_class_index.json") + download_file("https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/imagenet_2012_validation_synset_labels.txt", Path(__file__).parent.parent / "datasets/imagenet/imagenet_2012_validation_synset_labels.txt") + download_file("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar", Path(__file__).parent.parent / "datasets/imagenet/ILSVRC2012_img_val.tar") # 7GB + imagenet_extract(Path(__file__).parent.parent / "datasets/imagenet/ILSVRC2012_img_val.tar", Path(__file__).parent.parent / "datasets/imagenet/val") + imagenet_prepare_val() + if os.getenv['IMGNET_TRAIN'] is not None: + download_file("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar", Path(__file__).parent.parent / "datasets/imagenet/ILSVRC2012_img_train.tar") #138GB! + imagenet_extract(Path(__file__).parent.parent / "datasets/imagenet/ILSVRC2012_img_train.tar", Path(__file__).parent.parent / "datasets/imagenet/train") + imagenet_prepare_train() diff --git a/docs/abstractions.py b/docs/abstractions.py index 28d9e514c9..fbaecad46e 100644 --- a/docs/abstractions.py +++ b/docs/abstractions.py @@ -99,8 +99,8 @@ class LazyOp: arg: Optional[Any] = None # and an optional static argument # there's currently 27 Ops you have to implement for an accelerator. -class UnaryOps(Enum): NOOP = auto(); EXP = auto(); LOG = auto(); CAST = auto(); SIN = auto() -class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto() +class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto() +class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto() class ReduceOps(Enum): SUM = auto(); MAX = auto() class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto() class FusedOps(Enum): MULACC = auto() @@ -158,7 +158,7 @@ class Interpreted: # and they have a lookup table to functions for the Ops fxn_for_op: Dict[Op, Callable] = { - UnaryOps.EXP: lambda x: np.exp(x), + UnaryOps.EXP2: lambda x: np.exp2(x), BinaryOps.ADD: lambda x,y: x+y} # Compiled backends take a little more (example: GPU and LLVM) diff --git a/docs/adding_new_accelerators.md b/docs/adding_new_accelerators.md index 8957435cfb..373f7a7a6d 100644 --- a/docs/adding_new_accelerators.md +++ b/docs/adding_new_accelerators.md @@ -7,7 +7,7 @@ It's pretty easy to add a new accelerator to tinygrad. All you need to do is imp These are the ops that you must implement for your accelerator of choice. Compiled Accelerators do not need to implement movement_ops, as they are handled b the ShapeTracker. ``` Buffer # class of memory on this device -unary_op (NOOP, EXP, LOG, CAST, SIN) # A -> A +unary_op (NOOP, EXP2, LOG2, CAST, SIN) # A -> A reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape) binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ, MAX) # A + A -> A (all the same size) movement_op (EXPAND, RESHAPE, PERMUTE, PAD, SHRINK, STRIDE) # A -> B (different size) diff --git a/docs/env_vars.md b/docs/env_vars.md index e1cdbd27c5..790eed06df 100644 --- a/docs/env_vars.md +++ b/docs/env_vars.md @@ -184,3 +184,9 @@ CI | [1] | disables some tests for CI Variable | Possible Value(s) | Description ---|---|--- BS | [8, 16, 32, 64, 128] | batch size to use + +### datasets/imagenet_download.py + +Variable | Possible Value(s) | Description +---|---|--- +IMGNET_TRAIN | [1] | download also training data with imagenet diff --git a/examples/llama.py b/examples/llama.py index 8d5c5edadc..fc91e11d1e 100755 --- a/examples/llama.py +++ b/examples/llama.py @@ -10,7 +10,7 @@ from tqdm import tqdm np.set_printoptions(linewidth=200) from typing import Optional, Tuple -from tinygrad.helpers import getenv, DEBUG +from tinygrad.helpers import dtypes, getenv, DEBUG from tinygrad.lazy import Device from extra.helpers import Timing from tinygrad.tensor import Tensor @@ -143,14 +143,13 @@ class Transformer: # get only the part we are using. making it contiguous avoids more kernel calls freqs_cis = self.freqs_cis[:, start_pos:start_pos+seqlen].contiguous().realize() - if seqlen > 1: mask = np.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=np.float32) mask = np.triu(mask, k=start_pos + 1) # TODO: this is hard to do in tinygrad mask = Tensor(mask) else: mask = None - + # mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1) if seqlen > 1 else None #TODO: Pending(#942) for layer in self.layers: h.realize() # TODO: why do i need this? h = layer(h, start_pos, freqs_cis, mask) diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py index 04c3234b8b..ad250a2011 100644 --- a/examples/stable_diffusion.py +++ b/examples/stable_diffusion.py @@ -460,6 +460,7 @@ class CLIPTextTransformer: x = self.embeddings(input_ids, list(range(len(input_ids)))) causal_attention_mask = np.triu(np.ones((1,1,77,77), dtype=np.float32) * -np.inf, k=1) x = self.encoder(x, Tensor(causal_attention_mask, device=x.device)) + # x = self.encoder(x, Tensor.full((1, 1, 77, 77), float("-inf")).triu(1)) # TODO: Pending(#942) return self.final_layer_norm(x) # Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license) diff --git a/extra/onnx.py b/extra/onnx.py index 86fe64d0fb..896aab468c 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -3,8 +3,8 @@ from google.protobuf.internal.containers import RepeatedCompositeFieldContainer import importlib import numpy as np from tinygrad.tensor import Tensor -from tinygrad.helpers import prod -from tinygrad.helpers import getenv, DEBUG +from tinygrad.helpers import prod, getenv, DEBUG, dtypes +from typing import List,Dict from onnx.onnx_pb import AttributeProto, ModelProto, TensorProto try: from onnx.helper import tensor_dtype_to_np_dtype @@ -15,7 +15,7 @@ except ImportError: # global numpy cache for parameters numpy_cache = {} -def safe_numpy(t): +def safe_numpy(t) -> np.ndarray: if not isinstance(t, Tensor): return t global numpy_cache if t not in numpy_cache: @@ -56,7 +56,7 @@ def get_run_onnx(onnx_model: ModelProto): else: raise Exception(f"can't parse {a.type} {a}") def attribute_to_dict(a: RepeatedCompositeFieldContainer[AttributeProto]): return {x.name:attribute_parse(x) for x in a} - tensors = {} + tensors: Dict[str, Tensor] = {} # get weights and biases for inp in onnx_model.graph.initializer: @@ -83,32 +83,43 @@ def get_run_onnx(onnx_model: ModelProto): def run_onnx(inputs={}, debug=False): if getenv("DEBUGONNX"): debug = True - input_tensors = {} - intermediate_tensors = {} + input_tensors: Dict[str,Tensor] = {} + intermediate_tensors: Dict[str,Tensor] = {} output_tensor_names = [x.name for x in onnx_model.graph.output] # get inputs for inp in onnx_model.graph.input: if inp.name in tensors: continue - shape = shape_to_tuple(inp.type.tensor_type.shape) + tmp=inp.type.optional_type.elem_type.tensor_type if inp.type.HasField("optional_type") else (inp.type.sequence_type.elem_type.tensor_type if inp.type.HasField("sequence_type") else inp.type.tensor_type) + shape = shape_to_tuple(tmp.shape) if len(shape) >= 1 and shape[0] == 0: shape = tuple([1]+list(shape[1:])) # 1 batch size if inp.name in inputs: - input_shape = inputs[inp.name].shape - if input_shape == (0,): raise NotImplementedError("empty tensors aren't supported in tinygrad") - assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}" if isinstance(inputs[inp.name], Tensor): input_tensors[inp.name] = inputs[inp.name] else: input_tensors[inp.name] = Tensor(inputs[inp.name], requires_grad=False) + input_shape = input_tensors[inp.name].shape + if input_shape == (0,): raise NotImplementedError("empty tensors aren't supported in tinygrad") + assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}" for _,v in input_tensors.items(): v.realize() else: raise Exception(f"no data for {inp.name} with shape {shape}") + def fetch_tensor(x: str): + if x in tensors: return tensors[x] + if x in intermediate_tensors: return intermediate_tensors[x] + if x != str(): return input_tensors[x] + return None + for num,n in enumerate(onnx_model.graph.node): - inp = [tensors[x] if x in tensors else (intermediate_tensors[x] if x in intermediate_tensors else (input_tensors[x] if x != str() else None)) for x in n.input] + inp: List[Tensor] = [] + if debug: print("inputs:") + for x in n.input: + t = fetch_tensor(x) + if debug: print(f"\t{x} - {t}") + inp.append(t) opt = attribute_dict[num] if debug: print(f"{num}: op {n.op_type} shape {[x.shape if isinstance(x, Tensor) else x for x in inp]} opt {opt}") - # free ones if n.op_type == "Relu": ret = inp[0].relu() elif n.op_type == "Sigmoid": ret = inp[0].sigmoid() @@ -128,7 +139,7 @@ def get_run_onnx(onnx_model: ModelProto): elif 'value_int' in opt: ret = Tensor(np.array(opt['value_int'], dtype=np.int64), requires_grad=False) elif 'value_floats' in opt: ret = Tensor(np.array(opt['value_floats'], dtype=np.float32), requires_grad=False) elif 'value_ints' in opt: ret = Tensor(np.array(opt['value_ints'], dtype=np.int64), requires_grad=False) - else: raise NotImplementedError(f'Constant not implemented') + else: raise NotImplementedError(f'Constant not implemented for {opt}') elif n.op_type == "Reshape": ret = inp[0].reshape([int(x) if x != 0 else inp[0].shape[i] for i,x in enumerate(safe_numpy(inp[1]))]) elif n.op_type == "Resize": # TODO: this is handcoded for YOLOv8 @@ -139,14 +150,14 @@ def get_run_onnx(onnx_model: ModelProto): ret = ret.reshape([x*y for x,y in zip(inp[0].shape, [int(x) for x in scales])]) elif n.op_type == "Gather": # TODO: is this correct? seems to work for simple gather ops - axis = opt['axis'] + axis = opt['axis'] if 'axis' in opt else 0 shape = list(inp[0].shape) indices = [shape[axis]+int(x) if x<0 else int(x) for x in safe_numpy(inp[1])] args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(shape)] for i in indices] ret = inp[0].slice(arg=args[0]).cat(*[inp[0].slice(arg=arg) for arg in args[1:]], dim=axis) ret = ret.reshape([s for i,s in enumerate(shape) if i != axis]) if len(indices) == 1 else ret # squeeze if needed elif n.op_type in ["Add", "Sub", "Mul", "Pow"]: - if (len(inp[0].shape) != len(inp[1].shape)) and (prod(inp[0].shape) == prod(inp[1].shape)): + if all([isinstance(x, Tensor) for x in inp]) and (len(inp[0].shape) != len(inp[1].shape)) and (prod(inp[0].shape) == prod(inp[1].shape)): inp[1] = inp[1].reshape(inp[0].shape) # TODO: is this right? if 'broadcast' in opt: inp[1] = inp[1].reshape([-1 if i == opt['broadcast'] else 1 for i in range(len(inp[0].shape))]) @@ -167,12 +178,12 @@ def get_run_onnx(onnx_model: ModelProto): elif n.op_type == "Slice": assert onnx_model_version >= 10, f'only onnx version >= 10 supported for slice' arg = [(0,x) for x in inp[0].shape] - starts, ends, axes = inp[1:4] - assert axes.shape == (1,) - axis, starts, ends = int(safe_numpy(axes)[0]), int(safe_numpy(starts)[0]), int(safe_numpy(ends)[0]) - ends = min(ends, inp[0].shape[axis]) - starts = starts + inp[0].shape[axis] if starts < 0 else starts - arg[axis] = (starts, ends) + starts, ends = inp[1:3] + axes = safe_numpy(Tensor.arange(inp[0].ndim, dtype=dtypes.int32) if len(inp) <= 3 else inp[3]) + steps = safe_numpy(inp[4])[0] if len(inp) > 4 else 1 + starts, ends = safe_numpy(starts.cast(dtypes.int32)).tolist(), safe_numpy(ends.cast(dtypes.int32)).tolist() # TODO: when indexing is added use that + for i,axis in enumerate(axes.tolist()): + arg[axis] = (starts[i] if starts[i] >= 0 else inp[0].shape[axis]+starts[i], ends[i] if ends[i] >= 0 else inp[0].shape[axis]+ends[i]) ret = inp[0].slice(arg=arg) elif n.op_type == "Shrink": bias = opt['bias'] if 'bias' in opt else 0 @@ -192,7 +203,10 @@ def get_run_onnx(onnx_model: ModelProto): if not isinstance(ret, tuple): ret = (ret, ) assert len(n.output) <= len(ret), f"expected output size must be less than {len(ret)}, it's {n.output}" if debug: print([x.shape if isinstance(x, Tensor) else None for x in ret]) - for i in range(len(n.output)): intermediate_tensors[n.output[i]] = ret[i] + if debug: print("outputs:") + for i in range(len(n.output)): + if debug: print(f"\t{n.output[i]} - {ret[i]}") + intermediate_tensors[n.output[i]] = ret[i] #print(ret[0].numpy().mean()) if num == ONNXLIMIT: output_tensor_names = n.output diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index ee0d617e7e..6e24e1daf7 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -75,7 +75,7 @@ def _padding(X, pads=None, auto_pad="NOTSET", axes=None, constant_value=0.): return zero_padded + constant_padder def Pad(x: Tensor, pads: Union[Tensor, Tuple[int, ...]], constant_value: Tensor=None, axes: Tensor=None, mode="constant", value: float=0.): - assert mode == "constant" + assert mode == "constant", f"WARNING: Pad mode {mode} not implemented" constant_value = value if constant_value is None else constant_value.numpy() seq_pads = list(pads) if isinstance(pads, tuple) else pads.numpy().astype(np.int32).tolist() seq_axes = axes.numpy().astype(np.int32).tolist() if axes is not None else None @@ -92,7 +92,7 @@ def AveragePool(X, kernel_shape, auto_pad="NOTSET", ceil_mode=0, count_include_p return padding_included / div def MaxPool(X, kernel_shape, auto_pad="NOTSET", ceil_mode=0, dilations=1, pads=None, storage_order=0, strides=1): - assert ceil_mode == 0 and storage_order == 0 + assert ceil_mode == 0 and storage_order == 0, f"WARNING: MaxPool ceil_mode {ceil_mode} and storage_order {storage_order} not implemented" return _padding(X, pads, auto_pad, constant_value=-np.inf, axes=tuple(range(len(X.shape)))[-2:]).max_pool2d(kernel_shape, stride=strides, dilation=dilations) def Conv(X, W, B=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, strides=1): @@ -110,8 +110,8 @@ def Dropout(data, ratio=0.5, training_mode=False, seed=None): mask = Tensor((rng.random(data.shape) >= ratio), requires_grad=False, device=data.device) return data * mask * (1/(1.0 - ratio)), mask -def Shape(data, end=None, start=0): return list(data.shape)[start:end] -def Size(data): return prod(data.shape) +def Shape(data, end=None, start=0): return Tensor(list(data.shape)[start:end], dtype=dtypes.int64) +def Size(data): return prod(data if isinstance(data, list) else data.shape) # TODO: this doesn't match Tensor.flatten behavior def Flatten(input, axis=1): @@ -145,7 +145,7 @@ def HardSwish(input): return input * HardSigmoid(input, 1/6, 0.5) def Celu(X, alpha=1.0): return X.relu() - (-alpha*(X/alpha).exp()+1).relu() def Selu(X, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): return gamma * (X.relu() - (-alpha*X.exp()+alpha).relu()) def Softplus(X): return X.softplus() -def PRelu(X, slope): return X.leakyrelu(slope) +def PRelu(X:Tensor, slope:Tensor): return X.clip(0, float("inf")) + X.clip(float("-inf"), 0) * slope def LeakyRelu(X, alpha=0.01): return X.leakyrelu(alpha) def ThresholdedRelu(X, alpha=1.0): return (X-alpha).relu() + (X-alpha).relu().sign() * alpha def Softmax_1(input, axis=1): return input.softmax(axis) @@ -191,6 +191,8 @@ def ReduceLogSumExp(data, axes=None, keepdims=1, noop_with_empty_axes=0): return def GlobalAveragePool(X): return X.mean(axis=tuple(range(2, len(X.shape))), keepdim=True) def GlobalMaxPool(X): return X.max(axis=tuple(range(2, len(X.shape))), keepdim=True) +def OptionalHasElement(x: Tensor=None): return Tensor(x is not None and x.numel() > 0, dtype=dtypes.bool) +def OptionalGetElement(x: Tensor=None): return x if x is not None else Tensor([], dtype=dtypes.float32) def Tile(input, repeats): repeats_ = [int(x) for x in safe_numpy(repeats)] @@ -200,13 +202,17 @@ def Tile(input, repeats): return input.reshape(new_shape).expand(expand_shape).reshape(final_shape) def Range(start, limit, delta): return Tensor.arange(safe_numpy(limit)[0], safe_numpy(start)[0], safe_numpy(delta)[0]) -def Where(condition:Tensor,X:Tensor,Y:Tensor): return condition.where(X, Y) +def Where(condition:Tensor,X:Tensor,Y:Tensor): return condition.where(X, Y).cast(X.dtype) def And(x:Tensor, y:Tensor): return Where((x==y), x, Tensor.zeros(*x.shape)).cast(dtypes.bool) def Or(x:Tensor, y:Tensor): return Where((x==y), x, Tensor.ones(*x.shape)).cast(dtypes.bool) def Xor(x:Tensor, y:Tensor): return Where((x==y), Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool) def Not(x:Tensor): return Where((x==1), Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool) +def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1): + k = int(k.numpy().item()) if k is not 0 else 0 # onnx passes k as a tensor int64 with one element, default is 0 + return x.triu(k) if upper else x.tril(k) + def ConstantOfShape(input, value:Tensor=None): if value is None: value=Tensor([0.0]) shape = [int(x) for x in safe_numpy(input)] @@ -228,3 +234,21 @@ def MeanVarianceNormalization(input, axis=(0, 2, 3)): data_mean = input.mean(axis=axis, keepdim=True) std = ((input**2).mean(axis=axis, keepdim=True) - data_mean**2).sqrt() return (input - data_mean) / (std + 1e-9) + +def NegativeLogLikelihoodLoss(input, target, weight=None, ignore_index=None, reduction="mean"): + N, C, i_shape = input.shape[0], input.shape[1], input.shape + t_shape = target.shape + if len(input.shape) != 3: + input = input.reshape((N, C, -1)) + target = target.reshape((N, -1)) + if weight is not None: + mask = target.unsqueeze(-1) == Tensor.arange(C,dtype=dtypes.int64).repeat((N, 1, 1)) + weight = (mask * weight).sum(axis=-1) + if ignore_index is not None: + cond = (target == ignore_index) + weight = cond.where(0, weight) if weight is not None else cond.where(Tensor.zeros(*target.shape), 1) + mask = target[:, None, :] == Tensor.arange(C).reshape([1, C] + [1]*(len(input.shape) -2)) + loss = (-mask * input).sum(axis=1) * (1 if weight is None else weight) + if reduction == "mean": return loss.mean() if weight is None else loss.sum() / weight.sum() + elif reduction == "sum": return loss.sum() + return loss.reshape(t_shape) if len(i_shape) != 3 else loss \ No newline at end of file diff --git a/test/external/external_metal_uaf.py b/test/external/external_metal_uaf.py new file mode 100644 index 0000000000..46b89460e8 --- /dev/null +++ b/test/external/external_metal_uaf.py @@ -0,0 +1,17 @@ +import weakref +import numpy as np +from tinygrad.tensor import Tensor, Device +Device.DEFAULT = "METAL" + +if __name__ == "__main__": + t = Tensor.zeros(3).realize() + wt = weakref.ref(t.lazydata.realized) + n = t.numpy() + t += 1 + n2 = t.numpy() + print(wt) + del t + print(wt) + print(n, n.base, n.base.base) + print(n2, n2.base, n2.base.base) + assert wt() is not None \ No newline at end of file diff --git a/test/external/external_test_onnx_backend.py b/test/external/external_test_onnx_backend.py index 8fc42cc619..1f4def4cea 100644 --- a/test/external/external_test_onnx_backend.py +++ b/test/external/external_test_onnx_backend.py @@ -92,18 +92,9 @@ backend_test.exclude('test_asin_*') backend_test.exclude('test_asinh_*') backend_test.exclude('test_atan_*') backend_test.exclude('test_atanh_*') -# backend_test.include('test_cos_*') -# backend_test.include('test_cosh_*') -# backend_test.exclude('test_sin_*') -# backend_test.include('test_sinh_*') -# backend_test.include('test_tanh_*') # no boolean ops (2d, 3d, 4d) -# backend_test.exclude('test_and*') -# backend_test.exclude('test_xor*') -# backend_test.exclude('test_or*') backend_test.exclude('test_bitshift_*') -# backend_test.include('test_not_*') # no scatter gather backend_test.exclude('test_gather_*') @@ -133,8 +124,10 @@ backend_test.exclude('test_bitwise_*') backend_test.exclude('test_blackmanwindow_*') backend_test.exclude('test_bernoulli_*') backend_test.exclude('test_cumsum_*') -backend_test.exclude('test_tril_*') -backend_test.exclude('test_triu_*') + +backend_test.exclude('test_tril_zero_cpu') # TODO: zero array support +backend_test.exclude('test_triu_zero_cpu') # TODO: zero array support + backend_test.exclude('test_col2im_*') backend_test.exclude('test_hammingwindow_*') backend_test.exclude('test_hannwindow_*') diff --git a/test/external/external_test_opt.py b/test/external/external_test_opt.py index c840fc9150..4ac2f6557b 100644 --- a/test/external/external_test_opt.py +++ b/test/external/external_test_opt.py @@ -85,7 +85,7 @@ class TestInferenceMinKernels(unittest.TestCase): args_tiny = {"dim": 512, "multiple_of": 256, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000} model = Transformer(**args_tiny) for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np)) - with CLCache(85): + with CLCache(94): model(Tensor([[1,2,3,4]]), 0).realize() @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented") diff --git a/test/test_ops.py b/test/test_ops.py index 28ba1bb40e..6054cf6c64 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1,5 +1,6 @@ import torch import time +import math import numpy as np import unittest from tinygrad.tensor import Tensor @@ -124,6 +125,18 @@ class TestOps(unittest.TestCase): tt2 = Tensor.ones(4, requires_grad=True) self.assertRaises(RuntimeError, (tt1 < tt2).sum().backward) + def test_tril(self): + helper_test_op([(3,3)], lambda x: x.tril(), lambda x: x.tril()) + helper_test_op([(3,3)], lambda x: x.tril(1), lambda x: x.tril(1)) + helper_test_op([(3,3)], lambda x: x.tril(-1), lambda x: x.tril(-1)) + helper_test_op([(5,3,3)], lambda x: x.tril(), lambda x: x.tril()) + helper_test_op([(5,3,3)], lambda x: x.tril(1), lambda x: x.tril(1)) + def test_triu(self): + helper_test_op([(3,3)], lambda x: x.triu(), lambda x: x.triu()) + helper_test_op([(3,3)], lambda x: x.triu(1), lambda x: x.triu(1)) + helper_test_op([(3,3)], lambda x: x.triu(-1), lambda x: x.triu(-1)) + helper_test_op([(5,3,3)], lambda x: x.triu(), lambda x: x.triu()) + helper_test_op([(5,3,3)], lambda x: x.triu(1), lambda x: x.triu(1)) def test_maximum(self): helper_test_op([(45,65), (45,65)], torch.maximum, Tensor.maximum) helper_test_op([(), ()], torch.maximum, Tensor.maximum) @@ -941,6 +954,17 @@ class TestOps(unittest.TestCase): helper_test_op([(4,4)], lambda x: x[:, 1:2][0:1]) helper_test_op([(4,4)], lambda x: x[:, 1:2][:, 0:1]) + @unittest.skip("this test is broken #862") + def test_max_inf(self): + n = Tensor([1, float("nan")]).max().numpy() + assert math.isnan(n.item()), f"{n.item()} is not nan" + + @unittest.skip("this test is broken #942") + def test_inf_where(self): + x = Tensor.full((3, 3), float("inf")) + n = (x < 0).where(x, 1).numpy() + assert np.all(n == 1.) + if __name__ == '__main__': np.random.seed(1337) unittest.main(verbosity=2) diff --git a/test/unit/test_symbolic.py b/test/unit/test_symbolic.py index 1e0def9bff..eed3fc2c0c 100644 --- a/test/unit/test_symbolic.py +++ b/test/unit/test_symbolic.py @@ -11,8 +11,8 @@ class TestSymbolic(unittest.TestCase): def test_ge(self): self.helper_test_variable(Variable("a", 3, 8)>=77, 0, 0, "0") self.helper_test_variable(Variable("a", 3, 8)>=9, 0, 0, "0") - self.helper_test_variable(Variable("a", 3, 8)>=8, 0, 1, "(a>=8)") - self.helper_test_variable(Variable("a", 3, 8)>=4, 0, 1, "(a>=4)") + self.helper_test_variable(Variable("a", 3, 8)>=8, 0, 1, "((a*-1)<-7)") + self.helper_test_variable(Variable("a", 3, 8)>=4, 0, 1, "((a*-1)<-3)") self.helper_test_variable(Variable("a", 3, 8)>=3, 1, 1, "1") self.helper_test_variable(Variable("a", 3, 8)>=2, 1, 1, "1") diff --git a/tinygrad/codegen/cstyle.py b/tinygrad/codegen/cstyle.py index 0633986760..590478b586 100644 --- a/tinygrad/codegen/cstyle.py +++ b/tinygrad/codegen/cstyle.py @@ -2,7 +2,7 @@ from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, import math, collections from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer, LocalTypes from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps -from tinygrad.helpers import getenv, partition, ImageDType, DEBUG, dtypes, colored, prod +from tinygrad.helpers import partition, ImageDType, DEBUG, dtypes, colored, prod from tinygrad.runtime.lib import RawConst from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, Node, SumNode, MulNode from tinygrad.lazy import LazyBuffer @@ -12,8 +12,6 @@ render_cl = render_python.copy() render_cl[DivNode] = lambda self,ops,ctx: f"({self.a.render(ops, ctx)}/{self.b})" render_cl[AndNode] = lambda self,ops,ctx: f"({'&&'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})" -NATIVE_EXPLOG = getenv("NATIVE_EXPLOG", 0) # this is needed as a switch for the tests to pass - class CStyleLanguage(NamedTuple): kernel_prefix: str = "" buffer_prefix: str = "" @@ -48,8 +46,8 @@ def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=F return idx, idy code_for_op: Final[Dict[Op, Callable]] = { - UnaryOps.EXP: lambda x: f"native_exp({x})" if NATIVE_EXPLOG else f"exp({x})", - UnaryOps.LOG: lambda x: f"native_log({x})" if NATIVE_EXPLOG else f"log({x})", + UnaryOps.EXP2: lambda x: f"exp2({x})", + UnaryOps.LOG2: lambda x: f"log2({x})", UnaryOps.SIN: lambda x: f"sin({x})", BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})", BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})", diff --git a/tinygrad/codegen/llvmir.py b/tinygrad/codegen/llvmir.py index bb028b220d..4fc21b545e 100644 --- a/tinygrad/codegen/llvmir.py +++ b/tinygrad/codegen/llvmir.py @@ -6,22 +6,21 @@ from tinygrad.helpers import dtypes from tinygrad.ops import Op, ASTRunner, UnaryOps, BinaryOps, FusedOps from tinygrad.lazy import LazyBuffer -from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, GeNode, LtNode, SumNode, AndNode +from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode def int_const(x): return ir.Constant(ir.IntType(64), x) render_llvm = { NumNode: lambda self,ops,ctx: int_const(self.b), MulNode: lambda self,ops,ctx: ctx.mul(self.a.render(ops,ctx), int_const(self.b)), DivNode: lambda self,ops,ctx: ctx.sdiv(self.a.render(ops,ctx), int_const(self.b)), ModNode: lambda self,ops,ctx: ctx.srem(self.a.render(ops,ctx), int_const(self.b)), - GeNode: lambda self,ops,ctx: ctx.icmp_signed(">=", self.a.render(ops,ctx), int_const(self.b)), LtNode: lambda self,ops,ctx: ctx.icmp_signed("<", self.a.render(ops,ctx), int_const(self.b)), SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.add(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)), AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.and_(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)) } code_for_op: Final[Dict[Op, Callable]] = { - UnaryOps.EXP: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp', [ir.FloatType()]), [x], fastmath=('fast',)), - UnaryOps.LOG: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log', [ir.FloatType()]), [x], fastmath=('fast',)), + UnaryOps.EXP2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [ir.FloatType()]), [x], fastmath=('fast',)), + UnaryOps.LOG2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [ir.FloatType()]), [x], fastmath=('fast',)), UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [ir.FloatType()]), [x], fastmath=('fast',)), BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y, flags=('fast',)), BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)), @@ -88,11 +87,11 @@ def uops_to_llvm_ir(uops:List[UOp], bufs:List[LazyBuffer]) -> str: val = bb[-1].select(valid, bb[-1].load(bb[-1].gep(func.args[args.i], [aug_idx], inbounds=True)), ir.Constant(func_dtypes[args[0]], 0)) else: val = bb[-1].load(bb[-1].gep(func.args[args.i], [idx], inbounds=True)) - if func_dtypes[args.i] != ir.FloatType(): + if func_dtypes[args.i] != ir.FloatType(): if dtypes.is_int(bufs[args.i].dtype): val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(bufs[args.i].dtype) else bb[-1].sitofp(val, ir.FloatType()) else: - val = bb[-1].fpext(val, ir.FloatType()) + val = bb[-1].fpext(val, ir.FloatType()) lvars[newvar] = val if uop == UOps.STORE: assert args.valid.min == 1, "store must be valid" @@ -101,7 +100,7 @@ def uops_to_llvm_ir(uops:List[UOp], bufs:List[LazyBuffer]) -> str: if func_dtypes[0] != ir.FloatType(): if dtypes.is_int(bufs[args.i].dtype): element = bb[-1].fptoui(element, func_dtypes[0]) if dtypes.is_unsigned(bufs[args.i].dtype) else bb[-1].fptosi(element, func_dtypes[0]) - else: + else: element = bb[-1].fptrunc(element, func_dtypes[0]) bb[-1].store(element, bb[-1].gep(func.args[args.i], [idx], inbounds=True)) if uop == UOps.ALU: diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 450282b3ab..6201b82d4d 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -76,7 +76,7 @@ class dtypes: def from_np(x) -> DType: return asdict(dtypes())[np.dtype(x).name] bool: Final[DType] = DType(0, 1, "bool", bool) float16: Final[DType] = DType(0, 2, "half", np.float16) - float32: Final[DType] = DType(1, 4, "float", np.float32) + float32: Final[DType] = DType(4, 4, "float", np.float32) int8: Final[DType] = DType(0, 1, "char", np.int8) int32: Final[DType] = DType(1, 4, "int", np.int32) int64: Final[DType] = DType(2, 8, "int64", np.int64) diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index 57f10b4890..bdb91f3cbc 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -37,14 +37,14 @@ class Relu(Function): class Log(Function): def forward(self, x:LazyBuffer) -> LazyBuffer: self.x = x - return x.unary_op(UnaryOps.LOG) + return x.unary_op(UnaryOps.LOG2).binary_op(BinaryOps.MUL, x.const_like(math.log(2)/math.log(math.e))) def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.binary_op(BinaryOps.DIV, self.x) class Exp(Function): def forward(self, x:LazyBuffer) -> LazyBuffer: - self.ret = x.unary_op(UnaryOps.EXP) + self.ret = x.binary_op(BinaryOps.MUL, x.const_like(math.log(math.e)/math.log(2))).unary_op(UnaryOps.EXP2) return self.ret def backward(self, grad_output:LazyBuffer) -> LazyBuffer: @@ -128,7 +128,7 @@ class Pow(Function): def backward(self, grad_output:LazyBuffer): return grad_output.binary_op(BinaryOps.MUL, self.y.binary_op(BinaryOps.MUL, self.ret.binary_op(BinaryOps.DIV, self.x))) if self.needs_input_grad[0] else None, \ - grad_output.binary_op(BinaryOps.MUL, self.x.unary_op(UnaryOps.LOG).binary_op(BinaryOps.MUL, self.ret)) if self.needs_input_grad[1] else None + grad_output.binary_op(BinaryOps.MUL, self.x.unary_op(UnaryOps.LOG2).binary_op(BinaryOps.MUL, self.x.const_like(math.log(2)/math.log(math.e))).binary_op(BinaryOps.MUL, self.ret)) if self.needs_input_grad[1] else None class Div(Function): def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 720b22cafd..7dbd675088 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -8,7 +8,7 @@ from tinygrad.runtime.lib import RawBuffer, RawConst # these are the llops your accelerator must implement, along with toCpu # the Enum class doesn't work with mypy, this is static. sorry it's ugly -class UnaryOps(Enum): NOOP = auto(); EXP = auto(); LOG = auto(); CAST = auto(); SIN = auto() # noqa: E702 +class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto() # noqa: E702 class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto() # noqa: E702 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702 class FusedOps(Enum): MULACC = auto() # noqa: E702 diff --git a/tinygrad/runtime/lib.py b/tinygrad/runtime/lib.py index 313e13b0a4..32aa00a380 100644 --- a/tinygrad/runtime/lib.py +++ b/tinygrad/runtime/lib.py @@ -31,7 +31,8 @@ class RawBufferCopyIn(RawBuffer): class RawBufferMapped(RawBufferCopyIn): def _buffer(self) -> memoryview: raise NotImplementedError("must be implemented") - def toCPU(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=self.dtype.np) + # NOTE: this metadata prevents the backing buffer from being freed. hack can be removed with PEP688 + def toCPU(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore def _copyin(self, x:np.ndarray) -> None: np.copyto(self.toCPU(), x.reshape(-1)) # this one is simple enough that i moved it out of the runtimes diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py index 413acafd78..24c4e2620f 100644 --- a/tinygrad/runtime/ops_cpu.py +++ b/tinygrad/runtime/ops_cpu.py @@ -27,7 +27,7 @@ def einsum_mulacc(einsum, get_strides, expand): return mulacc numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{ - UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP: np.exp, UnaryOps.LOG: np.log, UnaryOps.CAST: lambda x,y: x.astype(y.np), UnaryOps.SIN: np.sin, + UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.CAST: lambda x,y: x.astype(y.np), UnaryOps.SIN: np.sin, BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.float32), MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to, MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, i) for i in arg)], diff --git a/tinygrad/runtime/ops_torch.py b/tinygrad/runtime/ops_torch.py index 3f08f90a10..35c1bad572 100644 --- a/tinygrad/runtime/ops_torch.py +++ b/tinygrad/runtime/ops_torch.py @@ -10,7 +10,7 @@ type_map = {torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch. inverse_type_map = {v:k for k,v in type_map.items()} torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{ - UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), UnaryOps.CAST: lambda x,y: x.type(next(k for k,v in type_map.items() if v==y)), UnaryOps.SIN: torch.sin, + UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.CAST: lambda x,y: x.type(next(k for k,v in type_map.items() if v==y)), UnaryOps.SIN: torch.sin, BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).float(), MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]), FusedOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(torch.promote_types(a.dtype, b.dtype)), lambda x: x.stride(), lambda x,s: x.expand(s)), diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index bcde187fbb..f0178791e9 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -25,7 +25,7 @@ class Node: def __neg__(self): return self*-1 def __add__(self, b:Union[Node, int]): return Variable.sum([self, b if isinstance(b, Node) else Variable.num(b)]) def __sub__(self, b:Union[Node, int]): return self+-b - def __ge__(self, b:int): return create_node(GeNode(self, b)) + def __ge__(self, b:int): return create_node(LtNode(-self, -b+1)) def __lt__(self, b:int): return create_node(LtNode(self, b)) def __mul__(self, b:int): if b == 0: return NumNode(0) @@ -125,16 +125,12 @@ def create_node(ret:Node): return ret class OpNode(Node): - def __init__(self, a:Node, b:int): + def __init__(self, a:Node, b:int): self.a, self.b = a, b self.min, self.max = self.get_bounds() - @abstractmethod + @abstractmethod def get_bounds(self) -> Tuple[int, int]: pass -class GeNode(OpNode): - def __mul__(self, b: int): return (self.a*b) >= (self.b*b) - def __floordiv__(self, b: int, _=False): return (self.a//b) >= (self.b//b) - def get_bounds(self) -> Tuple[int, int]: return int(self.a.min >= self.b), int(self.a.max >= self.b) class LtNode(OpNode): def __mul__(self, b: int): return (self.a*b) < (self.b*b) def __floordiv__(self, b: int, _=False): return (self.a//b) < (self.b//b) @@ -148,18 +144,18 @@ class MulNode(OpNode): def __mod__(self, b: int): a = (self.a * (self.b%b)) return Node.__mod__(a, b) - def get_bounds(self) -> Tuple[int, int]: + def get_bounds(self) -> Tuple[int, int]: return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b) class DivNode(OpNode): def __floordiv__(self, b: int, _=False): return self.a//(self.b*b) # two divs is one div - def get_bounds(self) -> Tuple[int, int]: + def get_bounds(self) -> Tuple[int, int]: assert self.a.min >= 0 return self.a.min//self.b, self.a.max//self.b class ModNode(OpNode): def __floordiv__(self, b: int, factoring_allowed=True): if (self.b % b == 0): return (self.a//b) % (self.b//b) # put the div inside mod return Node.__floordiv__(self, b, factoring_allowed) - def get_bounds(self) -> Tuple[int, int]: + def get_bounds(self) -> Tuple[int, int]: assert self.a.min >= 0 return (0, self.b-1) if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b) else (self.a.min%self.b, self.a.max%self.b) @@ -194,7 +190,7 @@ class SumNode(RedNode): if m > 1 and b%m == 0: return (self//m)//(b//m) return Node.__floordiv__(self, b, factoring_allowed) - def __mod__(self, b: int): + def __mod__(self, b: int): new_nodes = [] for x in self.nodes: if isinstance(x, NumNode): new_nodes.append(Variable.num(x.b%b)) @@ -202,7 +198,7 @@ class SumNode(RedNode): else: new_nodes.append(x) return Node.__mod__(Variable.sum(new_nodes), b) -class AndNode(RedNode): +class AndNode(RedNode): def __mul__(self, b: int): Variable.ands([x*b for x in self.nodes]) def __floordiv__(self, b: int, _=True): return Variable.ands([x//b for x in self.nodes]) @@ -218,7 +214,6 @@ render_python: Dict[Type, Callable] = { MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{self.b})", DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})", ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})", - GeNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}>={self.b})", LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{self.b})", SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})", AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})" diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 2cc3d64db5..69f6c98ff3 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -143,7 +143,7 @@ class Tensor: def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1, **kwargs) @staticmethod - def arange(stop, start=0, step=1, **kwargs): return Tensor.full(((stop-start)//step,), step).cumsum() + (start - step) + def arange(stop, start=0, step=1, **kwargs): return Tensor.full(((stop-start)//step,), step, **kwargs).cumsum() + (start - step) @staticmethod def full_like(tensor, fill_value, dtype:Optional[DType]=None, **kwargs): @@ -493,7 +493,7 @@ class Tensor: def cumsum(self, axis=0): x = self.permute(*(i for i in range(self.ndim) if i != axis), axis) - return x.reshape(1, 1, -1, self.shape[axis]).conv2d(Tensor.ones(1, 1, 1, self.shape[axis]), padding=(self.shape[axis]-1, 0, 0, 0)).reshape(*x.shape).permute(*range(axis), self.ndim - 1, *range(axis, self.ndim-1)) + return x.reshape(1, 1, -1, self.shape[axis]).conv2d(Tensor.ones(1, 1, 1, self.shape[axis], dtype=self.dtype, device=self.device), padding=(self.shape[axis]-1, 0, 0, 0)).reshape(*x.shape).permute(*range(axis), self.ndim - 1, *range(axis, self.ndim-1)) # ***** mlops (unary) ***** @@ -505,6 +505,12 @@ class Tensor: def sin(self): return mlops.Sin.apply(self) def cos(self): return ((math.pi/2)-self).sin() def tan(self): return self.sin() / self.cos() + + @staticmethod + def _tri(r:int, c:int, k:int=0) -> Tensor: return Tensor.arange(r).unsqueeze(1).expand(r,c) <= Tensor.arange(c-k, start=-k).unsqueeze(0).expand(r,c) + def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k).where(self, Tensor.zeros_like(self)) + def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1).where(Tensor.zeros_like(self), self) + # ***** math functions (unary) ***** def __neg__(self): return 0.0-self