mirror of https://github.com/tinygrad/tinygrad.git

remove a lot of useless returns
@@ -22,7 +22,6 @@ def unary_op(op, x, ret):
   elif op == UnaryOps.NEG: ret[:] = -x
   elif op == UnaryOps.SIGN: ret[:] = x.sign()
   else: raise Exception(f"{op} isn't supported")
-  return ret
 
 def binary_op(op, x, y, ret):
   if op == BinaryOps.ADD: ret[:] = x+y
@@ -33,7 +32,6 @@ def binary_op(op, x, y, ret):
   elif op == BinaryOps.A: ret[:] = x
   elif op == BinaryOps.CMPEQ: ret[:] = 1.0*(x==y)
   else: raise Exception(f"{op} isn't supported")
-  return ret
 
 def reduce_op(op, inp, ret):
   if inp.shape == ret.shape: # this is just a copy, regardless of the reduce op
@@ -47,7 +45,6 @@ def reduce_op(op, inp, ret):
     if op == ReduceOps.SUM: ret[:] = inp.sum(axis, keepdims=True)
     elif op == ReduceOps.MAX: ret[:] = inp.amax(axis, keepdims=True)
     else: raise Exception(f"{op} isn't supported")
-  return ret
 
 def movement_op(op, x, ret, arg=None):
   if op == MovementOps.RESHAPE: ret[:] = x.reshape(ret.shape)
@@ -58,7 +55,6 @@ def movement_op(op, x, ret, arg=None):
     slicee = [(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)]
     ret[:] = x[tuple([slice(x[0], x[1], None) for x in slicee])]
   else: raise Exception(f"{op} isn't supported")
-  return ret
 
 def get_tx(x, C):
   gx = x.reshape(C.bs,C.groups,C.cin,x.shape[2],x.shape[3])
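
All of these CPU ops write their result into a caller-provided buffer through NumPy's in-place slice assignment (ret[:] = ...), which is what makes the trailing returns dead weight: the caller already holds ret. A minimal sketch of the calling convention, with a hypothetical neg_op standing in for the real ops:

import numpy as np

def neg_op(x, ret):
  ret[:] = -x             # writes through into the caller's buffer; nothing to return

x = np.arange(3.0)
ret = np.empty_like(x)    # the caller allocates the output...
neg_op(x, ret)
print(ret)                # ...and reads it back directly: [-0. -1. -2.]
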
@@ -77,7 +73,6 @@ def conv(x,w,ret,stride,groups):
     #ijYXyx,kjyx -> iYXk ->ikYX
     tmp[:,g] += np.tensordot(tx[:,g], tw[g], ((1,4,5),(1,2,3)))
   ret[:] = np.moveaxis(tmp,4,2).reshape(C.bs, C.groups*C.rcout, C.oy, C.ox)
-  return ret
 
 def convdw(x,grad_output,dw,stride,groups):
   C = get_conv_args(x.shape, dw.shape, stride, groups)
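
The tensordot line is the whole convolution: per the comment, it contracts patches ijYXyx against kernels kjyx over (cin, y, x) to get iYXk. A self-contained check of that contraction against einsum, assuming stride 1, a single group, made-up sizes, and a naive loop standing in for the view get_tx produces:

import numpy as np

bs, cin, rcout, H, W, oy, ox = 2, 3, 4, 3, 3, 5, 5
x = np.random.randn(bs, cin, oy+H-1, ox+W-1)
tw = np.random.randn(rcout, cin, H, W)

# gather the sliding windows explicitly
tx = np.empty((bs, cin, oy, ox, H, W))
for Y in range(oy):
  for X in range(ox):
    tx[:, :, Y, X] = x[:, :, Y:Y+H, X:X+W]

# ijYXyx,kjyx -> iYXk, then move the channel axis forward (axis 3 here,
# axis 4 in the diff because of the extra groups dimension)
out = np.moveaxis(np.tensordot(tx, tw, ((1,4,5),(1,2,3))), 3, 1)
assert np.allclose(out, np.einsum('ijYXyx,kjyx->ikYX', tx, tw))
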
@@ -88,7 +83,6 @@ def convdw(x,grad_output,dw,stride,groups):
   for g in range(C.groups):
     #'ikYX,ijYXyx -> kjyx'
     gdw[g] += np.tensordot(ggg[:,g], tx[:,g], ((0,2,3),(0,2,3)))
-  return dw
 
 def convdx(grad_output,w,dx,stride,groups):
   C = get_conv_args(dx.shape, w.shape, stride, groups)
@@ -103,10 +97,8 @@ def convdx(grad_output,w,dx,stride,groups):
       for g in range(C.groups):
         tg = np.dot(ggg[:,g,:,Y,X].reshape(C.bs, -1), tw[g].reshape(C.rcout, -1))
         gdx[:, g, :, iY:iY+C.H, iX:iX+C.W] += tg.reshape((C.bs, C.cin, C.H, C.W))
-  return dx
 
 def processing_op(op,a,b,ret,stride,groups):
   if op == ProcessingOps.CONV: conv(a,b,ret,stride,groups)
   elif op == ProcessingOps.CONVT: convdx(a,b,ret,stride,groups)
   elif op == ProcessingOps.CONVDW: convdw(a,b,ret,stride,groups)
-  return ret
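
Same trick for the weight gradient: the 'ikYX,ijYXyx -> kjyx' contraction sums the outer product of output gradients and input patches over the batch and every output position. A quick shape/value check with made-up sizes:

import numpy as np

bs, cin, rcout, H, W, oy, ox = 2, 3, 4, 3, 3, 5, 5
tx = np.random.randn(bs, cin, oy, ox, H, W)   # input patches, ijYXyx
ggg = np.random.randn(bs, rcout, oy, ox)      # output gradient, ikYX

gdw = np.tensordot(ggg, tx, ((0,2,3),(0,2,3)))
assert gdw.shape == (rcout, cin, H, W)        # kjyx
assert np.allclose(gdw, np.einsum('ikYX,ijYXyx->kjyx', ggg, tx))
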
@@ -61,7 +61,7 @@ def unary_op(op, x, ret):
     float4 a = a_g[gid];
     res_g[gid] = """+code+""";
   }""")
-  unop([roundup(np.prod(ret.shape))//4], None, x.cl, ret.cl)
+  unop([roundup(prod(ret.shape))//4], None, x.cl, ret.cl)
   return ret
 
 @functools.lru_cache
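
The other change threaded through the GPU hunks is np.prod -> prod in every kernel launch size. The helper's definition isn't part of this diff; a plausible stand-in, assuming it just multiplies a shape tuple down to a plain Python int instead of the NumPy scalar np.prod returns:

from functools import reduce
from operator import mul

def prod(xs):
  # shape tuples are small, so a pure-Python fold is fine here (assumed definition)
  return reduce(mul, xs, 1)

prod((4, 3, 2))   # 24, as a plain int
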
@@ -100,7 +100,6 @@ def binary_op(op, x, y, ret):
   prg, is_float4 = get_binop_prg(code, tuple(complist))
   kernel_size = ((roundup(prod_list[0])//4) if is_float4 else prod_list[0]) if len(dimlist) > 0 else 1
   prg.binop(cl_queue, [kernel_size], None, x.cl, y.cl, ret.cl, *dimlist, *(prod_list[1:]))
-  return ret
 
 def reduce_op(op, inp, ret):
   if op == ReduceOps.SUM:
@@ -137,12 +136,11 @@ def reduce_op(op, inp, ret):
     }
     res_g[gid] = out;
   }""")
-  reduce([np.prod(ret.shape)], None, inp.cl,
-    i32(np.prod(inp.shape)//np.prod(ret.shape)), ret.cl,
-    i32(np.prod(ret.shape)), i32(len(ret.shape)),
+  reduce([prod(ret.shape)], None, inp.cl,
+    i32(prod(inp.shape)//prod(ret.shape)), ret.cl,
+    i32(prod(ret.shape)), i32(len(ret.shape)),
     buffer_np(np.array(inp.shape, dtype=np.int32)),
     buffer_np(np.array(ret.shape, dtype=np.int32)))
-  return ret
 
 def reshape(x, ret):
   cl.enqueue_copy(cl_queue, ret.cl, x.cl)
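
The launch geometry here: one work-item per output element, each folding prod(inp.shape)//prod(ret.shape) inputs. The host-side arithmetic, modelled with made-up shapes:

import numpy as np

inp_shape, ret_shape = (8, 3, 5), (8, 1, 5)   # a sum over axis 1
n_out = int(np.prod(ret_shape))               # global size: 40 work-items
per_item = int(np.prod(inp_shape)) // n_out   # each one reduces 3 inputs
assert np.zeros(inp_shape).sum(1, keepdims=True).shape == ret_shape
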
@@ -162,10 +160,9 @@ def perm_axis(inp, order, ret):
     }
     res_g[gid] = a_g[idx];
   }""")
-  perm([np.prod(inp.shape)], None, inp.cl, ret.cl, i32(len(inp.shape)),
+  perm([prod(inp.shape)], None, inp.cl, ret.cl, i32(len(inp.shape)),
     buffer_np(np.array(inp.shape, dtype=np.int32)),
     buffer_np(np.array(order, dtype=np.int32)))
-  return ret
 
 # TODO: merge this with perm axis
 def inner_slice(x, arg, ret):
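
Each work-item of the perm kernel decodes its gid into an output multi-index and maps it back through order to a flat input offset. An illustrative Python re-statement of that index arithmetic (not the kernel source):

import numpy as np

def perm_index(gid, ishape, order):
  # decode gid into the output multi-index, fastest-moving axis first
  idxs, rem = [], gid
  for d in reversed(range(len(ishape))):
    idxs.append(rem % ishape[order[d]])
    rem //= ishape[order[d]]
  idxs = idxs[::-1]
  # output axis d reads input axis order[d]
  iidx = [0]*len(ishape)
  for d, o in enumerate(order):
    iidx[o] = idxs[d]
  return int(np.ravel_multi_index(iidx, ishape))

x, order = np.arange(24).reshape(2, 3, 4), (2, 0, 1)
ref = np.transpose(x, order)
got = np.array([x.flat[perm_index(g, x.shape, order)] for g in range(x.size)]).reshape(ref.shape)
assert np.array_equal(got, ref)
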
@@ -185,18 +182,16 @@ def inner_slice(x, arg, ret):
     }
     output[gid] = zero ? input[iptr] : 0.0;
   }""")
-  gslice([np.prod(ret.shape)], None,
-    x.cl, ret.cl, i32(np.prod(ret.shape)), i32(len(ret.shape)),
+  gslice([prod(ret.shape)], None,
+    x.cl, ret.cl, i32(prod(ret.shape)), i32(len(ret.shape)),
     buffer_np(np.array(x.shape, dtype=np.int32)),
     buffer_np(np.array(ret.shape, dtype=np.int32)),
     buffer_np(np.array(shift, dtype=np.int32)))
-  return ret
 
 def movement_op(op, x, ret, arg=None):
   if op == MovementOps.RESHAPE: reshape(x, ret)
   elif op == MovementOps.PERMUTE: perm_axis(x, arg, ret)
   elif op == MovementOps.SLICE: inner_slice(x, arg, ret)
-  return ret
 
 def conv(x,w,ret,stride,groups):
   C = get_conv_args(x.shape, w.shape, stride, groups)
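
inner_slice reads every output position back into x's coordinate frame (shift holds the per-axis offsets) and writes zero wherever that lands out of bounds, which is what the CPU backend's pad-then-slice version at the top of this diff does too. A NumPy model of the semantics, not the kernel:

import numpy as np

def inner_slice_np(x, arg):
  # arg is a (start, end) pair per axis; either side may run past the array
  padding = [(max(0, -lo), max(0, hi - s)) for (lo, hi), s in zip(arg, x.shape)]
  xp = np.pad(x, padding)
  return xp[tuple(slice(lo + p[0], hi + p[0]) for (lo, hi), p in zip(arg, padding))]

x = np.arange(9.).reshape(3, 3)
print(inner_slice_np(x, [(-1, 2), (0, 3)]))   # one zero row enters from the top
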
@@ -227,7 +222,6 @@ def conv(x,w,ret,stride,groups):
   }""")
 
   conv_prg([C.bs*C.groups*C.rcout, C.oy, C.ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in C])
-  return ret
 
 # tensx = (bs, groups*cin, iy, ix)
 # tensw = (groups*rcout, cin, H, W)
@@ -255,7 +249,6 @@ def convdw(x,grad_output,dw,stride,groups):
       dw[get_global_id(0)*H*W + y*W + x] = acc;
   }""")
   convdw_prg([C.groups*C.rcout*C.cin, C.H, C.W], None, x.cl, grad_output.cl, dw.cl, *[i32(x) for x in C])
-  return dw
 
 def convdx(grad_output,w,dx,stride,groups):
   C = get_conv_args(dx.shape, w.shape, stride, groups)
@@ -284,10 +277,8 @@ def convdx(grad_output,w,dx,stride,groups):
   }
   """)
   convdx_prg([C.bs, C.groups, C.cin], None, w.cl, grad_output.cl, dx.cl, *[i32(x) for x in C])
-  return dx
 
 def processing_op(op,a,b,ret,stride,groups):
   if op == ProcessingOps.CONV: conv(a,b,ret,stride,groups)
   elif op == ProcessingOps.CONVT: convdx(a,b,ret,stride,groups)
   elif op == ProcessingOps.CONVDW: convdw(a,b,ret,stride,groups)
-  return ret
@@ -36,7 +36,6 @@ def convdw(x,grad_output,dw,stride,groups):
   grad_weight = torch.conv2d(x, grad_output, dilation=stride, groups=C.bs*C.groups*C.cin)
   grad_weight = grad_weight.reshape(C.bs, C.groups, C.cin, C.rcout, *grad_weight.shape[2:]).sum(dim=0).transpose(2, 1)
   dw[:] = grad_weight.reshape(C.groups*C.rcout, C.cin, *grad_weight.shape[3:])[:, :, :dw.shape[2], :dw.shape[3]]
-  return dw
 
 def processing_op(op,x,w,ret,stride,groups):
   if op == ProcessingOps.CONV:
@@ -52,4 +51,3 @@ def processing_op(op,x,w,ret,stride,groups):
     ret[:] = torch.conv_transpose2d(x, w, stride=stride, groups=groups, output_padding=output_padding)
   elif op == ProcessingOps.CONVDW:
     convdw(x,w,ret,stride,groups)
-  return ret
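
The CONVT branch above leans on the standard identity that the gradient of a strided conv2d with respect to its input is a conv_transpose2d with the same weights, with output_padding resolving the size ambiguity a stride introduces. A standalone autograd sanity check (made-up shapes; output_padding=1 recovers the 8x8 input under stride 2):

import torch

x = torch.randn(1, 3, 8, 8, requires_grad=True)
w = torch.randn(4, 3, 3, 3)
y = torch.conv2d(x, w, stride=2)
y.backward(torch.ones_like(y))
dx = torch.conv_transpose2d(torch.ones_like(y), w, stride=2, output_padding=1)
assert torch.allclose(x.grad, dx, atol=1e-5)
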