mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 14:43:57 -05:00
more whitespace
This commit is contained in:
@@ -89,6 +89,7 @@ python -m pytest
|
||||
### TODO
|
||||
|
||||
* Train an EfficientNet
|
||||
* Make broadcasting work on the backward pass (simple please)
|
||||
* EfficientNet backward pass
|
||||
* Tensors on GPU (GPU support, must support Mac)
|
||||
* Reduce code
|
||||
|
||||
@@ -84,7 +84,9 @@ register('matmul', Dot)
|
||||
class Pad2D(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, padding=None):
|
||||
return np.pad(x, ((0,0), (0,0), (padding[0], padding[1]), (padding[2], padding[3])))
|
||||
return np.pad(x,
|
||||
((0,0), (0,0),
|
||||
(padding[0], padding[1]), (padding[2], padding[3])))
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
@@ -163,7 +165,8 @@ class Conv2D(Function):
|
||||
ctx.stride = (ctx.stride, ctx.stride)
|
||||
cout,cin,H,W = w.shape
|
||||
ys,xs = ctx.stride
|
||||
bs,cin_,oy,ox = x.shape[0], x.shape[1], (x.shape[2]-(H-ys))//ys, (x.shape[3]-(W-xs))//xs
|
||||
bs,cin_ = x.shape[0], x.shape[1]
|
||||
oy,ox = (x.shape[2]-(H-ys))//ys, (x.shape[3]-(W-xs))//xs
|
||||
assert cin*ctx.groups == cin_
|
||||
assert cout % ctx.groups == 0
|
||||
rcout = cout//ctx.groups
|
||||
@@ -171,7 +174,9 @@ class Conv2D(Function):
|
||||
gx = x.reshape(bs,ctx.groups,cin,x.shape[2],x.shape[3])
|
||||
tx = np.lib.stride_tricks.as_strided(gx,
|
||||
shape=(bs, ctx.groups, cin, oy, ox, H, W),
|
||||
strides=(gx.strides[0], gx.strides[1], gx.strides[2], gx.strides[3]*ys, gx.strides[4]*xs, gx.strides[3], gx.strides[4]),
|
||||
strides=(gx.strides[0], gx.strides[1], gx.strides[2],
|
||||
gx.strides[3]*ys, gx.strides[4]*xs,
|
||||
gx.strides[3], gx.strides[4]),
|
||||
writeable=False,
|
||||
)
|
||||
tw = w.reshape(ctx.groups, rcout, cin, H, W)
|
||||
|
||||
Reference in New Issue
Block a user