mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
adding tuples is fine
This commit is contained in:
@@ -2,8 +2,8 @@
|
||||
from __future__ import annotations
|
||||
import functools, itertools
|
||||
import numpy as np
|
||||
from tinygrad.helpers import prod, argfix, make_pair, getenv, DEBUG
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union
|
||||
from tinygrad.helpers import prod, argfix, make_pair, getenv, DEBUG
|
||||
from tinygrad.lazy import Device, LazyBuffer
|
||||
|
||||
HLOP = getenv("HLOP", 0)
|
||||
@@ -328,12 +328,12 @@ class Tensor:
|
||||
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
|
||||
bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
|
||||
cin, cout = w.shape[-2], w.shape[-1]
|
||||
out_shape_t = tuple(list(self.shape[0:-2])+[cout,-1])
|
||||
out_shape_t = self.shape[0:-2] + (cout,-1)
|
||||
if len(self.shape) > 1:
|
||||
order = tuple(list(range(len(self.shape)-2))+[len(self.shape)-1, len(self.shape)-2])
|
||||
order = tuple(range(len(self.shape)-2)) + (len(self.shape)-1, len(self.shape)-2)
|
||||
else:
|
||||
order, out_shape_t = (0,), (cout, )
|
||||
worder = tuple(list(range(len(w.shape)-2))+[len(w.shape)-1, len(w.shape)-2])
|
||||
worder = tuple(range(len(w.shape)-2)) + (len(w.shape)-1, len(w.shape)-2)
|
||||
|
||||
# NOTE: with NHWC we can remove the transposes
|
||||
# bs x groups*cin x H x W
|
||||
|
||||
Reference in New Issue
Block a user