diff --git a/examples/gpt2.py b/examples/gpt2.py index e9a18c67c3..8226d5f048 100644 --- a/examples/gpt2.py +++ b/examples/gpt2.py @@ -35,7 +35,7 @@ class Attention: self.cache_kv = Tensor.zeros(2, bsz, MAX_CONTEXT, self.n_heads, self.head_dim, dtype=x.dtype).contiguous().realize() # update the cache - self.cache_kv.shrink((None, None,(start_pos,start_pos+seqlen),None,None)).assign(Tensor.stack([xk, xv])).realize() + self.cache_kv.shrink((None, None,(start_pos,start_pos+seqlen),None,None)).assign(Tensor.stack(xk, xv)).realize() if start_pos > 0: keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) diff --git a/examples/llama3.py b/examples/llama3.py index 4a8fc3637c..25e1b74e20 100644 --- a/examples/llama3.py +++ b/examples/llama3.py @@ -92,7 +92,7 @@ def NF4Linear(block_size): -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0, ] - CODE = Tensor.stack([Tensor(c) for c in _CODE]) + CODE = Tensor.stack(*[Tensor(c) for c in _CODE]) class _NF4Linear: def __init__(self, in_features, out_features, bias=False): assert not bias, "bias not supported" @@ -103,7 +103,7 @@ def NF4Linear(block_size): def __call__(self, x: Tensor) -> Tensor: high_bits = self.weight low_bits = (self.weight * 2 ** 4).contiguous() - unpacked = Tensor.stack([high_bits, low_bits], dim=-1).div(2 ** 4, upcast=False) + unpacked = Tensor.stack(high_bits, low_bits, dim=-1).div(2 ** 4, upcast=False) unscaled = CODE[unpacked].to(x.device).reshape(-1, block_size) * self.scale return x.linear(unscaled.reshape(self.out_features, self.in_features).T) diff --git a/examples/mamba.py b/examples/mamba.py index 1282f07a1a..8127f207bc 100644 --- a/examples/mamba.py +++ b/examples/mamba.py @@ -85,7 +85,7 @@ def selective_scan_ref( if i == u.shape[2] - 1: last_state = x ys.append(y) - y = Tensor.stack(ys, dim=2) # (batch dim L) + y = Tensor.stack(*ys, dim=2) # (batch dim L) out = y if D is None else y + u * D.reshape((-1, 1)) if z is not None: out = out * z.silu() diff --git a/examples/mask_rcnn.py b/examples/mask_rcnn.py index a23ef6bd1a..00d4c240c3 100644 --- a/examples/mask_rcnn.py +++ b/examples/mask_rcnn.py @@ -151,7 +151,7 @@ class Masker: for mask, box in zip(masks, boxes.bbox) ] if len(res) > 0: - res = torch.stack(res, dim=0)[:, None] + res = torch.stack(*res, dim=0)[:, None] else: res = masks.new_empty((0, 1, masks.shape[-2], masks.shape[-1])) return Tensor(res.numpy()) diff --git a/examples/so_vits_svc.py b/examples/so_vits_svc.py index b8872a04ce..9cab45f0ab 100644 --- a/examples/so_vits_svc.py +++ b/examples/so_vits_svc.py @@ -468,7 +468,7 @@ def repeat_expand_2d_left(content, target_len): # content : [h, t] if i >= temp[current_pos+1]: current_pos += 1 cols.append(content[:, current_pos]) - return Tensor.stack(cols).transpose(0, 1) + return Tensor.stack(*cols).transpose(0, 1) def load_fairseq_cfg(checkpoint_path): assert Path(checkpoint_path).is_file() diff --git a/examples/yolov8.py b/examples/yolov8.py index 5af2e6de71..1aa0b7facc 100644 --- a/examples/yolov8.py +++ b/examples/yolov8.py @@ -34,7 +34,7 @@ def preprocess(im, imgsz=640, model_stride=32, model_pt=True): same_shapes = all(x.shape == im[0].shape for x in im) auto = same_shapes and model_pt im = Tensor([compute_transform(x, new_shape=imgsz, auto=auto, stride=model_stride) for x in im]) - im = Tensor.stack(im) if im.shape[0] > 1 else im + im = Tensor.stack(*im) if im.shape[0] > 1 else im im = im[..., ::-1].permute(0, 3, 1, 2) # BGR to RGB, BHWC to BCHW, (n, 3, h, w) im /= 255 # 0 - 255 to 0.0 - 1.0 return im @@ -180,7 +180,7 @@ def make_anchors(feats, strides, grid_cell_offset=0.5): sx = sx.reshape(1, -1).repeat([h, 1]).reshape(-1) sy = sy.reshape(-1, 1).repeat([1, w]).reshape(-1) - anchor_points.append(Tensor.stack((sx, sy), -1).reshape(-1, 2)) + anchor_points.append(Tensor.stack(sx, sy, dim=-1).reshape(-1, 2)) stride_tensor.append(Tensor.full((h * w), stride)) anchor_points = anchor_points[0].cat(anchor_points[1], anchor_points[2]) stride_tensor = stride_tensor[0].cat(stride_tensor[1], stride_tensor[2]).unsqueeze(1) diff --git a/extra/models/bert.py b/extra/models/bert.py index dfb2f34d89..31cc8366e8 100644 --- a/extra/models/bert.py +++ b/extra/models/bert.py @@ -37,7 +37,7 @@ class BertForQuestionAnswering: start_logits = start_logits.reshape(-1, 1) end_logits = end_logits.reshape(-1, 1) - return Tensor.stack([start_logits, end_logits]) + return Tensor.stack(start_logits, end_logits) class BertForMLPerf: def __init__(self, hidden_size:int, intermediate_size:int, max_position_embeddings:int, num_attention_heads:int, num_hidden_layers:int, type_vocab_size:int, vocab_size:int, attention_probs_dropout_prob:float, hidden_dropout_prob:float) -> None: diff --git a/extra/models/llama.py b/extra/models/llama.py index 0dbf248766..75ea16a3ff 100644 --- a/extra/models/llama.py +++ b/extra/models/llama.py @@ -6,7 +6,7 @@ from tinygrad.helpers import getenv def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor: freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim)) freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0) - return Tensor.stack([freqs.cos().half(), freqs.sin().half()], dim=-1).reshape(1, end, 1, dim//2, 2) + return Tensor.stack(freqs.cos().half(), freqs.sin().half(), dim=-1).reshape(1, end, 1, dim//2, 2) # (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc) def complex_mult(A, c, d): @@ -72,7 +72,7 @@ class Attention: # update the cache assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}" - self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack([xk, xv])).realize() + self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize() keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xk values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xv diff --git a/extra/models/mask_rcnn.py b/extra/models/mask_rcnn.py index fb9d599e52..0ed0d3077c 100644 --- a/extra/models/mask_rcnn.py +++ b/extra/models/mask_rcnn.py @@ -66,7 +66,7 @@ def get_strides(shape): # with keys as integer array for all axes def tensor_getitem(tensor, *keys): # something about ints is broken with gpu, cuda - flat_keys = Tensor.stack([key.expand((sum(keys)).shape).reshape(-1) for key in keys], dim=1).cast(dtypes.int32) + flat_keys = Tensor.stack(*[key.expand((sum(keys)).shape).reshape(-1) for key in keys], dim=1).cast(dtypes.int32) strides = get_strides(tensor.shape) idxs = (flat_keys * strides).sum(1) gatherer = npgather if USE_NP_GATHER else _gather @@ -255,7 +255,7 @@ class BoxList: bb2 = self.bbox.clip(min_=0, max_=self.size[1] - TO_REMOVE)[:, 1] bb3 = self.bbox.clip(min_=0, max_=self.size[0] - TO_REMOVE)[:, 2] bb4 = self.bbox.clip(min_=0, max_=self.size[1] - TO_REMOVE)[:, 3] - self.bbox = Tensor.stack((bb1, bb2, bb3, bb4), dim=1) + self.bbox = Tensor.stack(bb1, bb2, bb3, bb4, dim=1) if remove_empty: box = self.bbox keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0]) @@ -394,7 +394,7 @@ class AnchorGenerator: shift_y, shift_x = meshgrid(shifts_y, shifts_x) shift_x = shift_x.reshape(-1) shift_y = shift_y.reshape(-1) - shifts = Tensor.stack((shift_x, shift_y, shift_x, shift_y), dim=1) + shifts = Tensor.stack(shift_x, shift_y, shift_x, shift_y, dim=1) anchors.append( (shifts.reshape(-1, 1, 4) + base_anchors.reshape(1, -1, 4)).reshape(-1, 4) @@ -525,7 +525,7 @@ class BoxCoder(object): targets_dw = ww * Tensor.log(gt_widths / ex_widths) targets_dh = wh * Tensor.log(gt_heights / ex_heights) - targets = Tensor.stack((targets_dx, targets_dy, targets_dw, targets_dh), dim=1) + targets = Tensor.stack(targets_dx, targets_dy, targets_dw, targets_dh, dim=1) return targets def decode(self, rel_codes, boxes): @@ -556,7 +556,7 @@ class BoxCoder(object): y = pred_ctr_y - 0.5 * pred_h w = pred_ctr_x + 0.5 * pred_w - 1 h = pred_ctr_y + 0.5 * pred_h - 1 - pred_boxes = Tensor.stack([x, y, w, h]).permute(1,2,0).reshape(rel_codes.shape[0], rel_codes.shape[1]) + pred_boxes = Tensor.stack(x, y, w, h).permute(1,2,0).reshape(rel_codes.shape[0], rel_codes.shape[1]) return pred_boxes @@ -632,8 +632,8 @@ class RPNPostProcessor: box_regression_list.append(tensor_gather(box_regression[batch_idx], topk_idx[batch_idx])) concat_anchors_list.append(tensor_gather(concat_anchors[batch_idx], topk_idx[batch_idx])) - box_regression = Tensor.stack(box_regression_list) - concat_anchors = Tensor.stack(concat_anchors_list) + box_regression = Tensor.stack(*box_regression_list) + concat_anchors = Tensor.stack(*concat_anchors_list) proposals = self.box_coder.decode( box_regression.reshape(-1, 4), concat_anchors.reshape(-1, 4) diff --git a/extra/models/rnnt.py b/extra/models/rnnt.py index 3902fe2313..8382aae6ac 100644 --- a/extra/models/rnnt.py +++ b/extra/models/rnnt.py @@ -145,7 +145,7 @@ class LSTM: new_hc = [x] for i, cell in enumerate(self.cells): new_hc.append(cell(new_hc[i][:x.shape[0]], hc[i])) - return Tensor.stack(new_hc[1:]).realize() + return Tensor.stack(*new_hc[1:]).realize() class StackTime: diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index cc9d91b392..9603638300 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -587,7 +587,7 @@ def AffineGrid(theta: Tensor, size: Tensor, align_corners=0): if dim == 0: stackable = [a.reshape(dim_sz, *[1]*(len(data_sz)-1)) + size_zeros, *stackable] elif dim == 1: stackable = [a.reshape(1, dim_sz, *[1]*(len(data_sz)-2)) + size_zeros, *stackable] else: stackable = [a.reshape(1, dim_sz) + size_zeros, *stackable] - original_grid = Tensor.stack(stackable, dim=len(data_sz)) + original_grid = Tensor.stack(*stackable, dim=len(data_sz)) if original_grid.ndim == 3: N, dim_2d, dim_homo = theta.shape assert dim_2d == 2 and dim_homo == 3 diff --git a/test/imported/test_indexing.py b/test/imported/test_indexing.py index 7ab5756ef3..cafb8883c7 100644 --- a/test/imported/test_indexing.py +++ b/test/imported/test_indexing.py @@ -131,13 +131,13 @@ class TestIndexing(unittest.TestCase): # indexing with step reference = consec((10, 10, 10)) - numpy_testing_assert_equal_helper(reference[1:5:2], Tensor.stack([reference[1], reference[3]], 0)) - numpy_testing_assert_equal_helper(reference[1:6:2], Tensor.stack([reference[1], reference[3], reference[5]], 0)) - numpy_testing_assert_equal_helper(reference[1:9:4], Tensor.stack([reference[1], reference[5]], 0)) - numpy_testing_assert_equal_helper(reference[2:4, 1:5:2], Tensor.stack([reference[2:4, 1], reference[2:4, 3]], 1)) - numpy_testing_assert_equal_helper(reference[3, 1:6:2], Tensor.stack([reference[3, 1], reference[3, 3], reference[3, 5]], 0)) - numpy_testing_assert_equal_helper(reference[None, 2, 1:9:4], Tensor.stack([reference[2, 1], reference[2, 5]], 0).unsqueeze(0)) - numpy_testing_assert_equal_helper(reference[:, 2, 1:6:2], Tensor.stack([reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5]], 1)) + numpy_testing_assert_equal_helper(reference[1:5:2], Tensor.stack(reference[1], reference[3], dim=0)) + numpy_testing_assert_equal_helper(reference[1:6:2], Tensor.stack(reference[1], reference[3], reference[5], dim=0)) + numpy_testing_assert_equal_helper(reference[1:9:4], Tensor.stack(reference[1], reference[5], dim=0)) + numpy_testing_assert_equal_helper(reference[2:4, 1:5:2], Tensor.stack(reference[2:4, 1], reference[2:4, 3], dim=1)) + numpy_testing_assert_equal_helper(reference[3, 1:6:2], Tensor.stack(reference[3, 1], reference[3, 3], reference[3, 5], dim=0)) + numpy_testing_assert_equal_helper(reference[None, 2, 1:9:4], Tensor.stack(reference[2, 1], reference[2, 5], dim=0).unsqueeze(0)) + numpy_testing_assert_equal_helper(reference[:, 2, 1:6:2], Tensor.stack(reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5], dim=1)) lst = [list(range(i, i+10)) for i in range(0, 100, 10)] tensor = Tensor(lst) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 48e0171cde..e3a110afff 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -393,7 +393,7 @@ class TestLinearizer(unittest.TestCase): def test_zero_fold(self): a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize() - r = Tensor.stack([a, b]) + r = Tensor.stack(a, b) k = Linearizer(*create_schedule([r.lazydata])[-1].ast) k.upcast() @@ -890,7 +890,7 @@ class TestHandCodedOpts(unittest.TestCase): assert k.upcasted == 1 and k.full_shape[-1] == 7 def test_masked_upcast_wino(self): - monster = Tensor.stack([Tensor.stack([Tensor.rand(16) for _ in range(6)]) for _ in range(6)]) + monster = Tensor.stack(*[Tensor.stack(*[Tensor.rand(16) for _ in range(6)]) for _ in range(6)]) s = create_schedule([monster.lazydata])[-1] k = Linearizer(*s.ast) diff --git a/test/test_ops.py b/test/test_ops.py index ec5832d2e8..883b00c06c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1062,9 +1062,9 @@ class TestOps(unittest.TestCase): lambda x: x.pad(((3,2),(0,1),(1,1)), value=value)[2:4,:,:]) def test_stack_slice(self): - helper_test_op([(4)], lambda x: torch.stack([x for i in range(3)])[0,:], lambda x: Tensor.stack([x for i in range(3)])[0,:]) - helper_test_op([(5)], lambda x: torch.stack([x for i in range(3)])[0,0], lambda x: Tensor.stack([x for i in range(3)])[0,0]) - helper_test_op([(4,4)], lambda x: torch.stack([x for i in range(4)])[3], lambda x: Tensor.stack([x for i in range(4)])[3]) + helper_test_op([(4)], lambda x: torch.stack([x for i in range(3)])[0,:], lambda x: Tensor.stack(*[x for i in range(3)])[0,:]) + helper_test_op([(5)], lambda x: torch.stack([x for i in range(3)])[0,0], lambda x: Tensor.stack(*[x for i in range(3)])[0,0]) + helper_test_op([(4,4)], lambda x: torch.stack([x for i in range(4)])[3], lambda x: Tensor.stack(*[x for i in range(4)])[3]) def test_transpose(self): helper_test_op([(3,3)], lambda x: x.T) @@ -1554,13 +1554,13 @@ class TestOps(unittest.TestCase): def test_stack(self): for dim in range(-1, 3): - helper_test_op([(45,65,3), (45,65,3), (45,65,3)], lambda x, y, z: torch.stack((x, y, z), dim), lambda x, y, z: Tensor.stack([x, y, z], dim)) + helper_test_op([(45,65,3), (45,65,3), (45,65,3)], lambda x, y, z: torch.stack((x, y, z), dim), lambda x, y, z: Tensor.stack(x, y, z, dim=dim)) with self.assertRaises(IndexError): - Tensor.stack([Tensor.randn(45, 65, 3)], dim=77) + Tensor.stack(Tensor.randn(45, 65, 3), dim=77) a = Tensor(3.14) - np.testing.assert_allclose(Tensor.stack([a, a]).numpy(), Tensor([3.14, 3.14]).numpy()) + np.testing.assert_allclose(Tensor.stack(a, a).numpy(), Tensor([3.14, 3.14]).numpy()) def test_repeat(self): x = Tensor.randn(4, 6, 3) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index f654b405ec..deebecd573 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -977,22 +977,20 @@ class Tensor: for d,k,s in zip(cat_dims, cat_dim_cumsum[:-1], slc): s[dim] = (k, cat_dim_cumsum[-1] - k - d) return functools.reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)]) - @staticmethod - def stack(tensors:Sequence[Tensor], dim:int=0) -> Tensor: + def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor: """ - Concatenates a sequence of tensors along a new dimension. + Concatenates self with other `Tensor` in `args` along a new dimension specified by `dim`. ```python exec="true" source="above" session="tensor" result="python" t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6]) - print(Tensor.stack([t0, t1, t2], dim=0).numpy()) + print(t0.stack(t1, t2, dim=0).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" - print(Tensor.stack([t0, t1, t2], dim=1).numpy()) + print(t0.stack([t1, t2], dim=1).numpy()) ``` """ - unsqueezed_tensors = [tensor.unsqueeze(dim) for tensor in tensors] # checks for shapes and number of dimensions delegated to cat - return unsqueezed_tensors[0].cat(*unsqueezed_tensors[1:], dim=dim) + return self.unsqueeze(dim).cat(*[t.unsqueeze(dim) for t in args], dim=dim) def repeat(self, repeats, *args) -> Tensor: """