adding tuples is fine

This commit is contained in:
George Hotz
2023-02-23 19:42:48 -08:00
parent 661812ffef
commit 4c54adeb18

View File

@@ -2,8 +2,8 @@
from __future__ import annotations
import functools, itertools
import numpy as np
from tinygrad.helpers import prod, argfix, make_pair, getenv, DEBUG
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union
from tinygrad.helpers import prod, argfix, make_pair, getenv, DEBUG
from tinygrad.lazy import Device, LazyBuffer
HLOP = getenv("HLOP", 0)
@@ -328,12 +328,12 @@ class Tensor:
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
cin, cout = w.shape[-2], w.shape[-1]
out_shape_t = tuple(list(self.shape[0:-2])+[cout,-1])
out_shape_t = self.shape[0:-2] + (cout,-1)
if len(self.shape) > 1:
order = tuple(list(range(len(self.shape)-2))+[len(self.shape)-1, len(self.shape)-2])
order = tuple(range(len(self.shape)-2)) + (len(self.shape)-1, len(self.shape)-2)
else:
order, out_shape_t = (0,), (cout, )
worder = tuple(list(range(len(w.shape)-2))+[len(w.shape)-1, len(w.shape)-2])
worder = tuple(range(len(w.shape)-2)) + (len(w.shape)-1, len(w.shape)-2)
# NOTE: with NHWC we can remove the transposes
# bs x groups*cin x H x W