fix backward convs (#746)

* fix backward convs

* no pushing in reduce

* late cout

* test_fold_4convs_sgd
Author: George Hotz
Committed by: GitHub
Date: 2023-04-14 10:42:11 -07:00
Parent: f7f416d6f4
Commit: 17e37157b6
3 changed files with 72 additions and 29 deletions

@@ -193,6 +193,34 @@ class TestOpt(unittest.TestCase):
assert len(GlobalCounters.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(GlobalCounters.cache)}"
Tensor.training = False
def test_fold_2convs_sgd(self):
# TODO: with Tensor.training
Tensor.training = True
img = Tensor.ones(2,3,64,64)
c1 = nn.Conv2d(3,16,3,bias=False)
c2 = nn.Conv2d(16,32,3,bias=False)
opt = optim.SGD(optim.get_parameters([c1, c2]))
with CLCache(allowed=9):
opt.zero_grad()
c2(c1(img).relu()).relu().sum().backward()
opt.step()
Tensor.training = False
def test_fold_4convs_sgd(self):
# TODO: with Tensor.training
Tensor.training = True
img = Tensor.ones(2,3,64,64)
c1 = nn.Conv2d(3,4,3,bias=False)
c2 = nn.Conv2d(4,8,3,bias=False)
c3 = nn.Conv2d(8,16,3,bias=False)
c4 = nn.Conv2d(16,32,3,bias=False)
opt = optim.SGD(optim.get_parameters([c1, c2, c3, c4]))
with CLCache(allowed=19):
opt.zero_grad()
c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
opt.step()
Tensor.training = False
def test_fold_conv_batchnorm_sgd(self):
# TODO: with Tensor.training
Tensor.training = True

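Worth noting: the kernel budgets in these tests grow linearly with depth, 9 allowed kernels for two convs and 19 for four, i.e. they fit 5*n - 1. That is consistent with a roughly constant number of fused kernels per conv layer covering the forward pass, the backward pass, and the SGD update, plus a small shared remainder, though the exact per-layer breakdown is not spelled out in the test.
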
@@ -30,6 +30,7 @@ def _ast_reduceops(self:LazyBuffer) -> LazyOp:
def _ast_binaryops(self:LazyBuffer) -> LazyOp:
real_srcs: Dict[LazyBuffer, Union[None, LazyOp, LazyBuffer]] = {x:None for x in get_buffers(self.op)}
# NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable but requires more thought how
# TODO: this can also support late fusion of BinaryOps, required for test_fold_conv_sgd
psrcs: List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype == ReduceOps and x.realized is None and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
intermediate_shape: Tuple[int, ...] = self.shape
if len(psrcs) == 1 and MERGE_ONE_REDUCE_INTO_ELEMENTWISE:
@@ -178,7 +179,8 @@ class LazyBuffer:
def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
if self.shape == tuple(new_shape): return self
return create_lazybuffer(self.device, new_shape, ReduceOps, LazyOp(op, (self,), new_shape), self.dtype)
srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
return create_lazybuffer(self.device, new_shape, ReduceOps, LazyOp(op, tuple(srcs), new_shape), self.dtype)
# shrink -> stride -> permute -> reshape -> pad -> expand
def movement_op(self:LazyBuffer, op:MovementOps, arg:Tuple[Any, ...]) -> LazyBuffer:
@@ -242,28 +244,29 @@ class LazyBuffer:
return ret
def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max(x.dtype for x in srcs) if op != UnaryOps.CAST else cast(DType, arg)
def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
new_srcs = []
for x in srcs:
mops: List[Tuple[MovementOps, Tuple[Any, ...]]] = []
bx = x
# backwalk all the movement ops. don't push PAD or EXPAND
while bx.realized is None and bx.optype == MovementOps and bx.op.op != MovementOps.EXPAND and (bx.op.op != MovementOps.PAD or SHUFFLE_PAD_OPS) and len(bx.children) <= 1:
assert isinstance(bx.op.op, MovementOps)
mops.append((bx.op.op, bx.op.arg))
bx = bx.op.src[0]
# NOTE: can't push pads with a div
if bx.realized is None and bx.optype == BinaryOps and len(bx.children) <= 1 and len(mops) and (all(x[0] != MovementOps.PAD for x in mops) or all(x.op != BinaryOps.DIV for x in get_lazyops(bx.op))):
new_srcs.append(replace_with_movement_ops(bx.op, mops[::-1]))
else:
new_srcs.append(x)
return tuple(new_srcs)
def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
# if we are separated from other binary ops by movement ops, we push those movement ops above those binaryops
if SHUFFLE_MOVEMENT_OPS:
new_srcs = []
did_replace = False
for x in srcs:
mops: List[Tuple[MovementOps, Tuple[Any, ...]]] = []
bx = x
# backwalk all the movement ops. don't push PAD or EXPAND
while bx.realized is None and bx.optype == MovementOps and bx.op.op != MovementOps.EXPAND and (bx.op.op != MovementOps.PAD or SHUFFLE_PAD_OPS) and len(bx.children) <= 1:
assert isinstance(bx.op.op, MovementOps)
mops.append((bx.op.op, bx.op.arg))
bx = bx.op.src[0]
# NOTE: can't push pads with a div
if bx.realized is None and bx.optype == BinaryOps and len(bx.children) <= 1 and len(mops) and (all(x[0] != MovementOps.PAD for x in mops) or all(x.op != BinaryOps.DIV for x in get_lazyops(bx.op))):
new_srcs.append(replace_with_movement_ops(bx.op, mops[::-1]))
did_replace = True
else:
new_srcs.append(x)
if did_replace: return elementwise_op(op, *new_srcs, arg=arg)
if SHUFFLE_MOVEMENT_OPS: srcs = _push_movement_ops(srcs)
# get outputs now
out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max(x.dtype for x in srcs) if op != UnaryOps.CAST else cast(DType, arg)
# push all contiguous to the end of BinaryOps. kernels 198 -> 196
if PUSH_CONTIGUOUS and any(x.realized is None and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs):

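The refactor above lifts the movement-op backwalk out of elementwise_op into a standalone _push_movement_ops helper, so the "push movement ops above binary ops" pass is defined once and can be called wherever it is needed. Below is a minimal self-contained sketch of that idea; the Buf/Node classes and string op names are made up for illustration and are not tinygrad's real LazyBuffer/LazyOp API.

# Toy sketch: push a chain of movement ops above a binary op.
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class Buf:                      # a leaf buffer with just a shape
  shape: Tuple[int, ...]

@dataclass(frozen=True)
class Node:                     # an op applied to sources: "ADD", "RESHAPE", "PERMUTE", ...
  op: str
  srcs: Tuple
  arg: Tuple = ()

MOVEMENT = {"RESHAPE", "PERMUTE", "SHRINK"}   # PAD/EXPAND intentionally excluded, as in the diff

def push_movement_ops(x):
  # backwalk the chain of movement ops sitting on top of x
  mops, bx = [], x
  while isinstance(bx, Node) and bx.op in MOVEMENT:
    mops.append((bx.op, bx.arg))
    bx = bx.srcs[0]
  # if we bottomed out on a binary op, replay the movement ops on each of its sources
  if isinstance(bx, Node) and bx.op == "ADD" and mops:
    new_srcs = []
    for s in bx.srcs:
      for mop, arg in reversed(mops):           # innermost movement op first
        s = Node(mop, (s,), arg)
      new_srcs.append(s)
    return Node("ADD", tuple(new_srcs))         # the binary op is now on top
  return x

# permute(x + y)  ->  permute(x) + permute(y)
x, y = Buf((4, 2)), Buf((4, 2))
out = push_movement_ops(Node("PERMUTE", (Node("ADD", (x, y)),), (1, 0)))
assert out.op == "ADD" and out.srcs[0].op == "PERMUTE"

As in the real helper, PAD and EXPAND are deliberately not pushed, and the actual code additionally checks realized, optype, child counts, and the PAD/DIV interaction before rewriting.
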
@@ -182,6 +182,7 @@ class Tensor:
for t0 in reversed(self.deepwalk()):
if not any(x.requires_grad for x in t0._ctx.parents):
del t0._ctx # TODO: does it help to delete this here ever?
continue
assert (t0.grad is not None)
grads = t0._ctx.backward(t0.grad.lazydata)
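
For reference, the behavioral change in this hunk is the early continue: a node whose parents all have requires_grad unset now has its saved context deleted and is skipped before the assert and the backward call. A tiny runnable illustration of that skip, with made-up stand-in classes rather than tinygrad's Tensor/Function internals:

class Ctx:                                    # stand-in for a saved autograd context
  def __init__(self, parents): self.parents, self.ran = parents, False
  def backward(self, grad): self.ran = True; return [grad for _ in self.parents]

class Parent:                                 # stand-in parent with a requires_grad flag
  def __init__(self, rg): self.requires_grad = rg

class Node:                                   # stand-in for a tensor in the deepwalk order
  def __init__(self, parents): self._ctx, self.grad = Ctx(parents), 1.0

nodes = [Node([Parent(True)]), Node([Parent(False), Parent(False)])]
for t0 in reversed(nodes):
  if not any(p.requires_grad for p in t0._ctx.parents):
    del t0._ctx                               # free the saved context early...
    continue                                  # ...and never run this node's backward
  t0._ctx.backward(t0.grad)
assert nodes[0]._ctx.ran and not hasattr(nodes[1], "_ctx")
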
@@ -303,7 +304,7 @@ class Tensor:
# ***** processing ops *****
def _pool(self, k_:Tuple[int, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
def _pool(self, k_:Tuple[int, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1, _insert_dims=tuple()) -> Tensor:
assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
assert len(k_) == len(s_) and len(k_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
@@ -311,7 +312,10 @@ class Tensor:
if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
e_ = [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)] # expands such that we don't need padding
xup = self.reshape(*prefix, *flatten((1,i) for i in i_)).expand(*prefix, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *[e*i for e,i in zip(e_, i_)])
xup = self.reshape(*prefix, *([1]*len(_insert_dims)), *flatten((1,i) for i in i_)).expand(*prefix, *_insert_dims, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *_insert_dims, *[e*i for e,i in zip(e_, i_)])
# NOTE: _insert_dims is required because reduces can't be merged (yet)
prefix += _insert_dims
slc_prefix += [(0,x) for x in _insert_dims]
# slide by dilation
xup = xup.slice(slc_prefix + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])
xup = xup.reshape(*prefix, *flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
@@ -325,7 +329,11 @@ class Tensor:
# TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
xup = self.slice(slc_prefix + [(0,o*s) for o,s in zip(o_, s_)])
xup = xup.reshape(*prefix, *flatten(((o, s) for o,s in zip(o_, s_))))
xup = xup.reshape(*prefix, *([1]*len(_insert_dims)), *flatten(((o, s) for o,s in zip(o_, s_))))
if len(_insert_dims):
xup = xup.expand(*prefix, *_insert_dims, *flatten(((o, s) for o,s in zip(o_, s_))))
prefix += _insert_dims
slc_prefix += [(0,x) for x in _insert_dims]
xup = xup.slice(slc_prefix + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
return xup.permute(*range(len(prefix)), *[len(prefix)+i*2 for i in range(len(k_))], *[len(prefix)+i*2+1 for i in range(len(k_))])
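
The new _insert_dims argument lets _pool broadcast extra leading output dimensions (the commented-out conv2d path passes cout//groups) during its own reshape/expand, so the channel expansion rides along with the pooling movement ops; per the NOTE it exists because reduces can't be merged yet, and per the conv2d TODO it cuts kernel count but was left disabled for being slower. A small NumPy sketch of the non-overlapping branch above (kernel == stride, dilation 1, so the trailing trim slice is a no-op), showing only the shape bookkeeping; this is illustrative toy code, not tinygrad's ShapeTracker path:

import numpy as np

bs, cin, rcout, oy, ox, k = 2, 3, 4, 5, 5, 2
x = np.arange(bs * cin * oy * k * ox * k, dtype=np.float32).reshape(bs, cin, oy * k, ox * k)

# split the spatial dims into (output, kernel) pairs and add a length-1 axis for _insert_dims
xup = x.reshape(bs, cin, 1, oy, k, ox, k)
# broadcast the inserted axis during the pool instead of expanding afterwards
xup = np.broadcast_to(xup, (bs, cin, rcout, oy, k, ox, k))
# group output dims then kernel dims, like _pool's final permute
xup = xup.transpose(0, 1, 2, 3, 5, 4, 6)
assert xup.shape == (bs, cin, rcout, oy, ox, k, k)   # (bs, groups*cin, rcout, oy, ox, H, W)
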
@@ -335,15 +343,19 @@ class Tensor:
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
(bs,cin_,_,_), (cout,cin,H,W) = self.shape, weight.shape
assert cin*groups == cin_, f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({cin*groups} vs. {cin_})"
assert groups*cin == cin_, f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"
padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
# conv2d is a pooling op (with padding)
x = self.pad2d(padding_)._pool((H,W), stride, dilation)
x = self.pad2d(padding_)._pool((H,W), stride, dilation) # (bs, groups*cin, oy, ox, H, W)
rcout, oy, ox = cout//groups, x.shape[2], x.shape[3]
x = x.reshape(bs, groups, cin, 1, oy, ox, H, W).expand(bs, groups, cin, rcout, oy, ox, H, W).permute(0,1,3,4,5,2,6,7)
oy, ox, rcout = x.shape[2], x.shape[3], cout//groups
# NOTE: we do this expand explicitly so the permute isn't pushed in the binop
x = x.reshape(bs, groups, 1, cin, oy, ox, H, W).expand(bs, groups, rcout, cin, oy, ox, H, W).permute(0,1,2,4,5,3,6,7)
# expand the channels with the pool
# TODO: this reduces the number of kernels, but it's slower!
#x = self.pad2d(padding_)._pool((H,W), stride, dilation, _insert_dims=(cout//groups,)) # (bs, groups*cin, rcout, oy, ox, H, W)
#rcout, oy, ox = x.shape[2:5]
#x = x.reshape(bs, groups, cin, rcout, oy, ox, H, W).permute(0,1,3,4,5,2,6,7)
# conv! broadcasted to (bs, groups, rcout, oy, ox, cin, H, W)
ret = (x * weight.reshape(1, groups, rcout, 1, 1, cin, H, W)).sum((-3, -2, -1), keepdim=True).reshape(bs, cout, oy, ox)
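
For readers following the shape bookkeeping, here is a NumPy sketch of the same conv2d data flow for the simple case (groups=1, stride=1, dilation=1, no padding), checked against an einsum reference. It mirrors the pool / reshape / expand / permute / broadcasted multiply / sum sequence above, but it is illustrative toy code, not tinygrad's implementation:

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

bs, cin, cout, H, W = 2, 3, 4, 3, 3
rng = np.random.default_rng(0)
img = rng.standard_normal((bs, cin, 8, 8), dtype=np.float32)
weight = rng.standard_normal((cout, cin, H, W), dtype=np.float32)

# the "pool": every HxW patch of the input -> (bs, cin, oy, ox, H, W)
patches = sliding_window_view(img, (H, W), axis=(2, 3))
oy, ox = patches.shape[2], patches.shape[3]

# expand a cout axis next to cin, then move cin,H,W to the back, mirroring
# reshape(bs, groups, 1, cin, oy, ox, H, W) -> expand -> permute(0,1,2,4,5,3,6,7)
x = np.broadcast_to(patches.reshape(bs, 1, 1, cin, oy, ox, H, W),
                    (bs, 1, cout, cin, oy, ox, H, W))
x = x.transpose(0, 1, 2, 4, 5, 3, 6, 7)            # (bs, 1, cout, oy, ox, cin, H, W)

# conv! broadcasted multiply against the weights, then reduce over (cin, H, W)
w = weight.reshape(1, 1, cout, 1, 1, cin, H, W)
out = (x * w).sum(axis=(-3, -2, -1)).reshape(bs, cout, oy, ox)

# check against a straightforward einsum over the same patches
ref = np.einsum('bcyxhw,ochw->boyx', patches, weight)
assert np.allclose(out, ref, atol=1e-4)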