mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-18 10:31:41 -05:00
[tensor][perf] Replace list comprehension with *. (#1102)
It's more concise, idiomatic and faster: ``` In [8]: %timeit [1 for _ in range(100)] 2.12 µs ± 26.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each) In [9]: %timeit [1] * 100 515 ns ± 5.23 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each) ```
This commit is contained in:
@@ -464,8 +464,8 @@ class Tensor:
|
||||
#x = x.reshape(bs, groups, cin, rcout, oy, ox, H, W).permute(0,1,3,4,5,2,6,7)
|
||||
|
||||
# conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
|
||||
ret = (x * weight.reshape(1, groups, rcout, *[1 for _ in range(len(oyx))], cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx)
|
||||
return ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))
|
||||
ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx)
|
||||
return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
|
||||
|
||||
def dot(self, w:Tensor) -> Tensor:
|
||||
n1, n2 = len(self.shape), len(w.shape)
|
||||
|
||||
Reference in New Issue
Block a user