Mirror of https://github.com/tinygrad/tinygrad.git
Rewrote Tensor.cat to be shorter and (hopefully) clearer (#372)
* Rewrote Tensor.cat to be shorter and (hopefully) clearer
* Use cumsum[-1] instead of separate sum
@@ -1,6 +1,6 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import inspect, functools, importlib
+import inspect, functools, importlib, itertools
 import numpy as np
 from tinygrad.helpers import prod
 from typing import List, Tuple, Callable, Optional
@@ -139,28 +139,17 @@ class Tensor:
     assert s.step is None or s.step == 1
     return self.slice(arg = arg + [(0,self.shape[i]) for i in range(len(arg), len(self.shape))])
 
-  # TODO: there has to be a cleaner way to write this
   def cat(self, *args, dim=0):
     dim = (dim + len(self.shape)) if dim < 0 else dim
-    for y in args: assert len(self.shape) == len(y.shape)
+    for y in args:
+      assert len(y.shape) == len(self.shape)
+      assert all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim)
     args = [self] + list(args)
-    s = [[] for _ in range(len(args))]
-    for i in range(len(self.shape)):
-      if i != dim:
-        for y in args: assert self.shape[i] == y.shape[i]
-        for j in range(len(args)):
-          s[j].append((0, self.shape[i]))
-      else:
-        shape_sum = 0
-        for y in args: shape_sum += y.shape[i]
-        k = 0
-        for j,y in enumerate(args):
-          s[j].append((-k, shape_sum-k))
-          k += y.shape[i]
-    ret = self.slice(arg=s[0])
-    for ts,y in zip(s[1:], args[1:]):
-      ret += y.slice(arg=ts)
-    return ret
+    shape_cumsum = [0, *itertools.accumulate(y.shape[dim] for y in args)]
+    slc = [[(0, s) for s in self.shape] for _ in args]
+    for s,k in zip(slc, shape_cumsum): s[dim] = (-k, shape_cumsum[-1]-k)
+    slices = [arg.slice(arg=s) for arg,s in zip(args, slc)]
+    return functools.reduce(Tensor.__iadd__, slices)
 
   def matmul(self:Tensor, w:Tensor):
     # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
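For readers unfamiliar with the trick the rewrite relies on: tinygrad's Tensor.slice zero-pads any part of the requested range that falls outside the tensor, so each input can be sliced out to the full output extent at its own cumulative offset along dim, and the padded copies summed elementwise. Below is a minimal NumPy sketch of that idea, assuming a non-negative dim and matching dtypes; zero_pad_slice is a hypothetical stand-in written here to mimic Tensor.slice's zero-padding, not tinygrad API.

import itertools, functools
import numpy as np

def zero_pad_slice(x, arg):
  # Hypothetical stand-in for tinygrad's Tensor.slice: arg is a list of
  # (start, end) per axis; anything outside x's bounds reads as zero.
  out = np.zeros([e - s for s, e in arg], dtype=x.dtype)
  src = tuple(slice(max(s, 0), min(e, n)) for (s, e), n in zip(arg, x.shape))
  dst = tuple(slice(max(-s, 0), max(-s, 0) + (sl.stop - sl.start))
              for (s, _), sl in zip(arg, src))
  out[dst] = x[src]
  return out

def cat(*args, dim=0):
  # Same bookkeeping as the rewritten Tensor.cat: cumulative offsets along
  # dim, one zero-padded slice per input, then an elementwise sum.
  shape_cumsum = [0, *itertools.accumulate(y.shape[dim] for y in args)]
  slc = [[(0, s) for s in args[0].shape] for _ in args]
  for s, k in zip(slc, shape_cumsum):
    s[dim] = (-k, shape_cumsum[-1] - k)
  return functools.reduce(np.add, [zero_pad_slice(y, s) for y, s in zip(args, slc)])

a = np.arange(6).reshape(2, 3)
b = np.arange(6, 12).reshape(2, 3)
assert (cat(a, b, dim=0) == np.concatenate([a, b], axis=0)).all()

The sum is safe because, along dim, each padded slice is nonzero only on its own band [k, k + y.shape[dim]), so the bands tile the output without overlap.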