Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 23:18:04 -05:00
fix backward convs (#746)
* fix backward convs
* no pushing in reduce
* late cout
* test_fold_4convs_sgd
test/external/external_test_opt.py
@@ -193,6 +193,34 @@ class TestOpt(unittest.TestCase):
    assert len(GlobalCounters.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(GlobalCounters.cache)}"
    Tensor.training = False

  def test_fold_2convs_sgd(self):
    # TODO: with Tensor.training
    Tensor.training = True
    img = Tensor.ones(2,3,64,64)
    c1 = nn.Conv2d(3,16,3,bias=False)
    c2 = nn.Conv2d(16,32,3,bias=False)
    opt = optim.SGD(optim.get_parameters([c1, c2]))
    with CLCache(allowed=9):
      opt.zero_grad()
      c2(c1(img).relu()).relu().sum().backward()
      opt.step()
    Tensor.training = False

  def test_fold_4convs_sgd(self):
    # TODO: with Tensor.training
    Tensor.training = True
    img = Tensor.ones(2,3,64,64)
    c1 = nn.Conv2d(3,4,3,bias=False)
    c2 = nn.Conv2d(4,8,3,bias=False)
    c3 = nn.Conv2d(8,16,3,bias=False)
    c4 = nn.Conv2d(16,32,3,bias=False)
    opt = optim.SGD(optim.get_parameters([c1, c2, c3, c4]))
    with CLCache(allowed=19):
      opt.zero_grad()
      c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
      opt.step()
    Tensor.training = False

  def test_fold_conv_batchnorm_sgd(self):
    # TODO: with Tensor.training
    Tensor.training = True
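The new tests run a small conv stack forward, backward, and through an SGD step under CLCache, which fails if more kernels are enqueued than allowed. The budgets grow by roughly five kernels per extra conv: 9 for two convs, 19 for four. A minimal sketch of how the pattern would extend to three convs, assuming the same TestOpt class and imports as the file above; the allowed count of 14 is an assumption interpolated from the two tests, not a value from this commit:

  def test_fold_3convs_sgd(self):
    # hypothetical extension of the pattern above; not part of this commit
    Tensor.training = True
    img = Tensor.ones(2,3,64,64)
    c1 = nn.Conv2d(3,8,3,bias=False)
    c2 = nn.Conv2d(8,16,3,bias=False)
    c3 = nn.Conv2d(16,32,3,bias=False)
    opt = optim.SGD(optim.get_parameters([c1, c2, c3]))
    with CLCache(allowed=14):  # assumed budget, midway between the 2-conv (9) and 4-conv (19) tests
      opt.zero_grad()
      c3(c2(c1(img).relu()).relu()).relu().sum().backward()
      opt.step()
    Tensor.training = False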
@@ -30,6 +30,7 @@ def _ast_reduceops(self:LazyBuffer) -> LazyOp:
def _ast_binaryops(self:LazyBuffer) -> LazyOp:
  real_srcs: Dict[LazyBuffer, Union[None, LazyOp, LazyBuffer]] = {x:None for x in get_buffers(self.op)}
  # NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable but requires more thought how
  # TODO: this can also support late fusion of BinaryOps, required for test_fold_conv_sgd
  psrcs: List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype == ReduceOps and x.realized is None and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
  intermediate_shape: Tuple[int, ...] = self.shape
  if len(psrcs) == 1 and MERGE_ONE_REDUCE_INTO_ELEMENTWISE:
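The psrcs filter decides when a ReduceOps child can be merged into the elementwise AST: the reduce must be unrealized, each buffer must have at most one consumer, and the buffer k feeding the binary op must have the same element count as the reduce output x, i.e. only reshapes or permutes in between, never an expand. A standalone sketch of just that size check, using plain shape tuples instead of LazyBuffers (the shapes are illustrative, not taken from the diff):

from math import prod

def same_numel(k_shape, x_shape):
  # mergeable only if the movement ops between the reduce and the binop preserve the element count
  return prod(k_shape) == prod(x_shape)

# a conv-style reduce output viewed through a reshape keeps its element count -> mergeable
assert same_numel((2,32,60,60), (2,32,60,60,1,1,1))
# an expand after the reduce changes the element count -> not mergeable
assert not same_numel((2,32,60,60,3,3), (2,32,60,60,1,1))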
@@ -178,7 +179,8 @@ class LazyBuffer:

  def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
    if self.shape == tuple(new_shape): return self
    return create_lazybuffer(self.device, new_shape, ReduceOps, LazyOp(op, (self,), new_shape), self.dtype)
    srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
    return create_lazybuffer(self.device, new_shape, ReduceOps, LazyOp(op, tuple(srcs), new_shape), self.dtype)

  # shrink -> stride -> permute -> reshape -> pad -> expand
  def movement_op(self:LazyBuffer, op:MovementOps, arg:Tuple[Any, ...]) -> LazyBuffer:
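reduce_op now runs its source through _push_movement_ops when SHUFFLE_MOVEMENT_OPS is set, so a permute or reshape sitting between an earlier binary op and the reduce gets hoisted above that binary op and the reduce can fuse with it. The rewrite is valid because movement ops commute with elementwise ops; a small numpy illustration of the identity (numpy is used only for the demonstration, it is not part of this diff):

import numpy as np

a = np.random.randn(4, 5).astype(np.float32)
b = np.random.randn(4, 5).astype(np.float32)

# mul -> permute -> reduce: three separate graph nodes before fusion
lhs = (a * b).T.sum(axis=1)
# permute the inputs instead, so the mul and the reduce sit next to each other and can fuse into one kernel
rhs = (a.T * b.T).sum(axis=1)
assert np.allclose(lhs, rhs)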
@@ -242,28 +244,29 @@ class LazyBuffer:

    return ret

def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
  out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max(x.dtype for x in srcs) if op != UnaryOps.CAST else cast(DType, arg)
def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
  new_srcs = []
  for x in srcs:
    mops: List[Tuple[MovementOps, Tuple[Any, ...]]] = []
    bx = x
    # backwalk all the movement ops. don't push PAD or EXPAND
    while bx.realized is None and bx.optype == MovementOps and bx.op.op != MovementOps.EXPAND and (bx.op.op != MovementOps.PAD or SHUFFLE_PAD_OPS) and len(bx.children) <= 1:
      assert isinstance(bx.op.op, MovementOps)
      mops.append((bx.op.op, bx.op.arg))
      bx = bx.op.src[0]
    # NOTE: can't push pads with a div
    if bx.realized is None and bx.optype == BinaryOps and len(bx.children) <= 1 and len(mops) and (all(x[0] != MovementOps.PAD for x in mops) or all(x.op != BinaryOps.DIV for x in get_lazyops(bx.op))):
      new_srcs.append(replace_with_movement_ops(bx.op, mops[::-1]))
    else:
      new_srcs.append(x)
  return tuple(new_srcs)

def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
  # if we are separated from other binary ops by movement ops, we push those movement ops above those binaryops
  if SHUFFLE_MOVEMENT_OPS:
    new_srcs = []
    did_replace = False
    for x in srcs:
      mops: List[Tuple[MovementOps, Tuple[Any, ...]]] = []
      bx = x
      # backwalk all the movement ops. don't push PAD or EXPAND
      while bx.realized is None and bx.optype == MovementOps and bx.op.op != MovementOps.EXPAND and (bx.op.op != MovementOps.PAD or SHUFFLE_PAD_OPS) and len(bx.children) <= 1:
        assert isinstance(bx.op.op, MovementOps)
        mops.append((bx.op.op, bx.op.arg))
        bx = bx.op.src[0]
      # NOTE: can't push pads with a div
      if bx.realized is None and bx.optype == BinaryOps and len(bx.children) <= 1 and len(mops) and (all(x[0] != MovementOps.PAD for x in mops) or all(x.op != BinaryOps.DIV for x in get_lazyops(bx.op))):
        new_srcs.append(replace_with_movement_ops(bx.op, mops[::-1]))
        did_replace = True
      else:
        new_srcs.append(x)
    if did_replace: return elementwise_op(op, *new_srcs, arg=arg)
  if SHUFFLE_MOVEMENT_OPS: srcs = _push_movement_ops(srcs)

  # get outputs now
  out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max(x.dtype for x in srcs) if op != UnaryOps.CAST else cast(DType, arg)

  # push all contiguous to the end of BinaryOps. kernels 198 -> 196
  if PUSH_CONTIGUOUS and any(x.realized is None and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs):
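The factored-out _push_movement_ops keeps the two guards from the old inline code in elementwise_op: never push EXPAND, and only push PAD when SHUFFLE_PAD_OPS is set and no DIV is involved, because padding introduces zeros and a division performed after the pad is pushed would turn the padded region into 0/0. A small numpy illustration of why that guard exists (numpy is only for the demonstration):

import numpy as np

a = np.array([1.0, 2.0], dtype=np.float32)
b = np.array([2.0, 4.0], dtype=np.float32)

# pad applied after the div: the padded region is a well-defined 0
pad_after = np.pad(a / b, (0, 1))
# pad pushed above the div: the padded region becomes 0/0 = nan
with np.errstate(invalid='ignore'):
  pad_before = np.pad(a, (0, 1)) / np.pad(b, (0, 1))

assert not np.array_equal(pad_after, pad_before)  # hence "can't push pads with a div"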
@@ -182,6 +182,7 @@ class Tensor:

    for t0 in reversed(self.deepwalk()):
      if not any(x.requires_grad for x in t0._ctx.parents):
        del t0._ctx # TODO: does it help to delete this here ever?
        continue
      assert (t0.grad is not None)
      grads = t0._ctx.backward(t0.grad.lazydata)
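The added del t0._ctx frees the op context as soon as the walk decides nothing upstream needs a gradient. The surrounding loop is a plain reverse-topological traversal; a self-contained toy version of that pattern, independent of tinygrad's Tensor and Function classes (all names here are invented for the illustration):

class Node:
  def __init__(self, name, parents=(), requires_grad=False):
    self.name, self.parents, self.requires_grad = name, tuple(parents), requires_grad

def deepwalk(node):
  # depth-first post-order: parents are emitted before their children, i.e. a topological order
  visited, order = set(), []
  def walk(n):
    if n in visited: return
    visited.add(n)
    for p in n.parents: walk(p)
    order.append(n)
  walk(node)
  return order

x = Node("x", requires_grad=True)
w = Node("w", requires_grad=True)
c = Node("c")                          # constant: no gradient flows through here
y = Node("y", (x, w), requires_grad=True)
z = Node("z", (y, c), requires_grad=True)

for t0 in reversed(deepwalk(z)):
  if not any(p.requires_grad for p in t0.parents):
    continue                           # leaves and grad-free subgraphs are skipped (their ctx could be freed here)
  print("backprop through", t0.name)   # prints z, then y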
@@ -303,7 +304,7 @@ class Tensor:

  # ***** processing ops *****

  def _pool(self, k_:Tuple[int, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
  def _pool(self, k_:Tuple[int, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1, _insert_dims=tuple()) -> Tensor:
    assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
    s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
    assert len(k_) == len(s_) and len(k_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
@@ -311,7 +312,10 @@ class Tensor:
    if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
      o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
      e_ = [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)] # expands such that we don't need padding
      xup = self.reshape(*prefix, *flatten((1,i) for i in i_)).expand(*prefix, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *[e*i for e,i in zip(e_, i_)])
      xup = self.reshape(*prefix, *([1]*len(_insert_dims)), *flatten((1,i) for i in i_)).expand(*prefix, *_insert_dims, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *_insert_dims, *[e*i for e,i in zip(e_, i_)])
      # NOTE: _insert_dims is required because reduces can't be merged (yet)
      prefix += _insert_dims
      slc_prefix += [(0,x) for x in _insert_dims]
      # slide by dilation
      xup = xup.slice(slc_prefix + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])
      xup = xup.reshape(*prefix, *flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
@@ -325,7 +329,11 @@ class Tensor:
    # TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
    o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
    xup = self.slice(slc_prefix + [(0,o*s) for o,s in zip(o_, s_)])
    xup = xup.reshape(*prefix, *flatten(((o, s) for o,s in zip(o_, s_))))
    xup = xup.reshape(*prefix, *([1]*len(_insert_dims)), *flatten(((o, s) for o,s in zip(o_, s_))))
    if len(_insert_dims):
      xup = xup.expand(*prefix, *_insert_dims, *flatten(((o, s) for o,s in zip(o_, s_))))
      prefix += _insert_dims
      slc_prefix += [(0,x) for x in _insert_dims]
    xup = xup.slice(slc_prefix + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
    return xup.permute(*range(len(prefix)), *[len(prefix)+i*2 for i in range(len(k_))], *[len(prefix)+i*2+1 for i in range(len(k_))])
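_pool gained an _insert_dims parameter so a broadcast dimension (conv2d's output channels) can be materialized inside the pooling reshape/expand, since a second expand-plus-reduce cannot be merged yet. The shape bookkeeping can be checked with the same formula the code uses; the concrete sizes below are an illustrative assumption, not values from the diff:

bs, cin = 2, 3                      # batch size and input channels (example values)
ky, kx = 3, 3                       # pooling kernel, i.e. (H, W) in conv2d
iy, ix = 64, 64                     # input spatial dims (no padding in this example)
k_, s_, d_, i_ = (ky, kx), (1, 1), (1, 1), (iy, ix)

# output spatial dims, same formula as in _pool
o_ = [(i - d*(k-1) - 1)//s + 1 for i, d, k, s in zip(i_, d_, k_, s_)]
assert o_ == [62, 62]

# without _insert_dims the pooled tensor is (bs, cin, oy, ox, H, W)
print((bs, cin, *o_, *k_))          # (2, 3, 62, 62, 3, 3)

# with _insert_dims=(rcout,) the extra dim lands right after the prefix
rcout = 4                           # assumed output channels per group
print((bs, cin, rcout, *o_, *k_))   # (2, 3, 4, 62, 62, 3, 3), i.e. (bs, groups*cin, rcout, oy, ox, H, W)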
@@ -335,15 +343,19 @@ class Tensor:

  def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
    (bs,cin_,_,_), (cout,cin,H,W) = self.shape, weight.shape
    assert cin*groups == cin_, f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({cin*groups} vs. {cin_})"
    assert groups*cin == cin_, f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"
    padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])

    # conv2d is a pooling op (with padding)
    x = self.pad2d(padding_)._pool((H,W), stride, dilation)
    x = self.pad2d(padding_)._pool((H,W), stride, dilation) # (bs, groups*cin, oy, ox, H, W)
    rcout, oy, ox = cout//groups, x.shape[2], x.shape[3]
    x = x.reshape(bs, groups, cin, 1, oy, ox, H, W).expand(bs, groups, cin, rcout, oy, ox, H, W).permute(0,1,3,4,5,2,6,7)

    oy, ox, rcout = x.shape[2], x.shape[3], cout//groups
    # NOTE: we do this expand explicitly so the permute isn't pushed in the binop
    x = x.reshape(bs, groups, 1, cin, oy, ox, H, W).expand(bs, groups, rcout, cin, oy, ox, H, W).permute(0,1,2,4,5,3,6,7)
    # expand the channels with the pool
    # TODO: this reduces the number of kernels, but it's slower!
    #x = self.pad2d(padding_)._pool((H,W), stride, dilation, _insert_dims=(cout//groups,)) # (bs, groups*cin, rcout, oy, ox, H, W)
    #rcout, oy, ox = x.shape[2:5]
    #x = x.reshape(bs, groups, cin, rcout, oy, ox, H, W).permute(0,1,3,4,5,2,6,7)

    # conv! broadcasted to (bs, groups, rcout, oy, ox, cin, H, W)
    ret = (x * weight.reshape(1, groups, rcout, 1, 1, cin, H, W)).sum((-3, -2, -1), keepdim=True).reshape(bs, cout, oy, ox)
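The reworked conv2d keeps cin as a broadcast axis and reduces over (cin, H, W) at the end: the product is shaped (bs, groups, rcout, oy, ox, cin, H, W) against the weight viewed as (1, groups, rcout, 1, 1, cin, H, W). A numpy sketch of the same index gymnastics for the groups=1, stride=1, dilation=1 case, checked against a naive loop (numpy and the example sizes are illustrative only, not part of the diff):

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

bs, cin, cout, H, W, iy, ix = 1, 2, 3, 3, 3, 6, 6   # example sizes; groups=1 so rcout == cout
oy, ox = iy - H + 1, ix - W + 1
x = np.random.randn(bs, cin, iy, ix).astype(np.float32)
w = np.random.randn(cout, cin, H, W).astype(np.float32)

# the pooled view _pool produces: (bs, cin, oy, ox, H, W)
pooled = sliding_window_view(x, (H, W), axis=(2, 3))

# reshape/expand/permute exactly as the new conv2d code does, giving (bs, groups, rcout, oy, ox, cin, H, W)
xb = pooled.reshape(bs, 1, cin, 1, oy, ox, H, W)
xb = np.broadcast_to(xb, (bs, 1, cin, cout, oy, ox, H, W)).transpose(0, 1, 3, 4, 5, 2, 6, 7)

# conv! broadcasted multiply, then reduce over (cin, H, W)
ret = (xb * w.reshape(1, 1, cout, 1, 1, cin, H, W)).sum(axis=(-3, -2, -1)).reshape(bs, cout, oy, ox)

# naive reference convolution
ref = np.zeros((bs, cout, oy, ox), dtype=np.float32)
for co in range(cout):
  for yy in range(oy):
    for xx in range(ox):
      ref[:, co, yy, xx] = (x[:, :, yy:yy+H, xx:xx+W] * w[co]).sum(axis=(1, 2, 3))

assert np.allclose(ret, ref, atol=1e-4)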