From 40f8eb73832da3dd278b5fd75439033163dcdb06 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Thu, 9 Jun 2022 09:19:52 -0700
Subject: [PATCH] namedtuple for conv_args

---
 tinygrad/helpers.py         |  9 ++++++---
 tinygrad/llops/ops_gpu.py   | 18 ++++++------------
 tinygrad/llops/ops_torch.py | 16 ++++++----------
 tinygrad/mlops.py           | 14 ++++++--------
 4 files changed, 24 insertions(+), 33 deletions(-)

diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 028221bb5b..f358d64055 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -1,4 +1,5 @@
 import numpy as np
+from collections import namedtuple
 
 def binary_broadcast(x_shape, y_shape, extra=False):
   n_dims = max(len(x_shape), len(y_shape))
@@ -22,14 +23,16 @@
   return (shape_ret, dimlist, complist) if extra else shape_ret
 
 def get_conv_args(x_shape, w_shape, stride, groups):
+  conv_args = namedtuple('conv_args',
+    ['H', 'W', 'groups', 'rcout', 'cin', 'oy', 'ox', 'iy', 'ix', 'ys', 'xs', 'bs'])
   cout,cin,H,W = w_shape
-  ys,xs = stride
+  ys,xs = (stride, stride) if isinstance(stride, int) else stride
   bs,cin_,iy,ix = x_shape
   oy,ox = (iy-(H-ys))//ys, (ix-(W-xs))//xs
-  if cin*groups != cin_: raise Exception(f"Input Tensor shape {x_shape} does not match the shape of the weights {w.shape}. ({cin*ctx.groups} vs. {cin_})")
+  if cin*groups != cin_: raise Exception(f"Input Tensor shape {x_shape} does not match the shape of the weights {w_shape}. ({cin*groups} vs. {cin_})")
   assert cout % groups == 0
   rcout = cout//groups
-  return H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs
+  return conv_args(H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs)
 
 from enum import Enum
 UnaryOps = Enum("UnaryOps", ["RELU", "EXP", "LOG", "NEG", "SIGN"])
diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py
index 3be6e38ef5..32fe115f7c 100644
--- a/tinygrad/llops/ops_gpu.py
+++ b/tinygrad/llops/ops_gpu.py
@@ -228,11 +228,7 @@ def matmul(a, b, c, transpose_a=False, transpose_b=False):
     osize)
   return c
 
-
-# TODO: combine any of these three?
-def conv(x,w,ret,conv_args):
-  H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
-
+def conv(x,w,ret,C):
   # input  = (bs, groups, cin, iy, ix)
   # weight = (groups, rcout, cin, H, W)
   # output = (bs, groups, rcout, oy, ox)
@@ -259,15 +255,14 @@ def conv(x,w,ret,conv_args):
      output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
   }""")
 
-  conv_prg([bs*groups*rcout, oy, ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in conv_args])
+  conv_prg([C.bs*C.groups*C.rcout, C.oy, C.ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in C])
   return ret
 
 # tensx = (bs, groups*cin, iy, ix)
 # tensw = (groups*rcout, cin, H, W)
 # ggg = (bs, groups*rout, oy, ox)
-def convdw(x,grad_output,dw,conv_args):
-  H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
+def convdw(x,grad_output,dw,C):
   convdw_prg = clbuild("convdw", """
   __kernel void convdw(__global const float *tensx, __global const float *ggg, __global float *dw,
     int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
 
@@ -287,11 +282,10 @@ def convdw(x,grad_output,dw,conv_args):
      } }
     dw[get_global_id(0)*H*W + y*W + x] = acc;
   }""")
-  convdw_prg([groups*rcout*cin, H, W], None, x.cl, grad_output.cl, dw.cl, *[i32(x) for x in conv_args])
+  convdw_prg([C.groups*C.rcout*C.cin, C.H, C.W], None, x.cl, grad_output.cl, dw.cl, *[i32(x) for x in C])
   return dw
 
-def convdx(w,grad_output,dx,conv_args):
-  H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
+def convdx(w,grad_output,dx,C):
   convdx_prg = clbuild("convdx", """
   __kernel void convdx(__global const float *tensw, __global const float *ggg, __global float *dx,
     int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
@@ -316,5 +310,5 @@ def convdx(w,grad_output,dx,conv_args):
     } } }
   """)
 
-  convdx_prg([bs, groups, cin], None, w.cl, grad_output.cl, dx.cl, *[i32(x) for x in conv_args])
+  convdx_prg([C.bs, C.groups, C.cin], None, w.cl, grad_output.cl, dx.cl, *[i32(x) for x in C])
   return dx
diff --git a/tinygrad/llops/ops_torch.py b/tinygrad/llops/ops_torch.py
index d6e12de07d..5807e0588e 100644
--- a/tinygrad/llops/ops_torch.py
+++ b/tinygrad/llops/ops_torch.py
@@ -23,18 +23,14 @@ from tinygrad.llops.ops_cpu import unary_op, binary_op, reduce_op, reshape, perm
 
 # ************* processing ops *************
 
-def conv(x,w,ret,conv_args):
-  # TODO: replace conv_args with stride and groups everywhere
-  H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
-  ret[:] = torch.nn.functional.conv2d(x, w, stride=(ys,xs), groups=groups)
+def conv(x,w,ret,C):
+  ret[:] = torch.nn.functional.conv2d(x, w, stride=(C.ys,C.xs), groups=C.groups)
   return ret
 
-def convdw(x,grad_output,dw,conv_args):
-  H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
-  dw[:] = torch.nn.grad.conv2d_weight(x, dw.shape, grad_output, stride=(ys,xs), groups=groups)
+def convdw(x,grad_output,dw,C):
+  dw[:] = torch.nn.grad.conv2d_weight(x, dw.shape, grad_output, stride=(C.ys,C.xs), groups=C.groups)
   return dw
 
-def convdx(w,grad_output,dx,conv_args):
-  H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
-  dx[:] = torch.nn.grad.conv2d_input(dx.shape, w, grad_output, stride=(ys,xs), groups=groups)
+def convdx(w,grad_output,dx,C):
+  dx[:] = torch.nn.grad.conv2d_input(dx.shape, w, grad_output, stride=(C.ys,C.xs), groups=C.groups)
   return dx
diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index 4ca7cb821b..c8cb3571f6 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -179,14 +179,12 @@ class Matmul(Function):
 
 class Conv2D(Function):
   def forward(ctx, x, w, stride=1, groups=1):
-    if isinstance(ctx.stride, int): ctx.stride = (ctx.stride, ctx.stride)
-    ctx.save_for_backward(x,w)
-    H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args = get_conv_args(x.shape, w.shape, ctx.stride, ctx.groups)
-    return ctx.op.conv(x, w, ctx.buffer((bs, groups*rcout, oy, ox)), conv_args)
+    C = get_conv_args(x.shape, w.shape, stride, groups)
+    ctx.save_for_backward(x,w,C)
+    return ctx.op.conv(x, w, ctx.buffer((C.bs, C.groups*C.rcout, C.oy, C.ox)), C)
 
   def backward(ctx, grad_output):
-    x, w = ctx.saved_tensors
-    H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args = get_conv_args(x.shape, w.shape, ctx.stride, ctx.groups)
-    dx = ctx.op.convdx(w, grad_output, ctx.buffer((bs, groups*cin, iy, ix)), conv_args) if ctx.needs_input_grad[0] else None
-    dw = ctx.op.convdw(x, grad_output, ctx.buffer((groups*rcout, cin, H, W)), conv_args) if ctx.needs_input_grad[1] else None
+    x, w, C = ctx.saved_tensors
+    dx = ctx.op.convdx(w, grad_output, ctx.buffer((C.bs, C.groups*C.cin, C.iy, C.ix)), C) if ctx.needs_input_grad[0] else None
+    dw = ctx.op.convdw(x, grad_output, ctx.buffer((C.groups*C.rcout, C.cin, C.H, C.W)), C) if ctx.needs_input_grad[1] else None
     return dx, dw
\ No newline at end of file
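
Not part of the patch: a minimal sketch of how the new conv_args namedtuple behaves after this change, assuming tinygrad.helpers from this commit is importable; the shapes below are made-up illustration values.

    from tinygrad.helpers import get_conv_args

    # x_shape = (bs, cin, iy, ix), w_shape = (cout, cin//groups, H, W); illustrative numbers only
    C = get_conv_args((1, 4, 32, 32), (8, 4, 3, 3), stride=1, groups=1)

    # fields are read by name instead of positional unpacking
    assert (C.oy, C.ox) == (30, 30)              # 3x3 kernel, stride 1 -> 30x30 output
    assert (C.bs, C.groups * C.rcout) == (1, 8)  # dims of the buffer allocated in Conv2D.forward

    # iteration order still matches the old flat tuple, which is why the GPU backend
    # can keep splatting it into kernel arguments with *[i32(x) for x in C]
    print(list(C))

Because the namedtuple is a drop-in replacement for the old 12-element tuple, backends can migrate to named access (C.ys, C.groups, ...) without breaking the call sites that still iterate over it.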