This commit is contained in:
George Hotz
2022-06-08 23:39:54 -07:00
parent 5a533fc073
commit 214fb8c974
2 changed files with 56 additions and 55 deletions

View File

@@ -14,21 +14,21 @@ def select_llops(ops):
class UnaryOp(Function):
def forward(ctx, input):
ctx.save_for_backward(input)
return ll.unary_op(ctx.fop, input, ll.Buffer(input.shape))
return ctx.op.unary_op(ctx.fop, input, ctx.op.Buffer(input.shape))
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return ll.binary_op(ctx.bop, input, grad_output, ll.Buffer(input.shape))
return ctx.op.binary_op(ctx.bop, input, grad_output, ctx.op.Buffer(input.shape))
class ReLU(UnaryOp):
fop = UnaryOps.RELU
def backward(ctx, grad_output):
input, = ctx.saved_tensors
ret = ll.Buffer(input.shape)
ll.unary_op(UnaryOps.SIGN, input, ret)
ll.unary_op(UnaryOps.RELU, ret, ret)
return ll.binary_op(BinaryOps.MUL, ret, grad_output, ret)
ret = ctx.op.Buffer(input.shape)
ctx.op.unary_op(UnaryOps.SIGN, input, ret)
ctx.op.unary_op(UnaryOps.RELU, ret, ret)
return ctx.op.binary_op(BinaryOps.MUL, ret, grad_output, ret)
class Log(UnaryOp):
fop = UnaryOps.LOG
@@ -36,7 +36,7 @@ class Log(UnaryOp):
class Exp(UnaryOp):
def forward(ctx, input):
ret = ll.unary_op(UnaryOps.EXP, input, ll.Buffer(input.shape))
ret = ctx.op.unary_op(UnaryOps.EXP, input, ctx.op.Buffer(input.shape))
ctx.save_for_backward(ret) # we save the output here, not the input
return ret
@@ -50,81 +50,81 @@ def reduce_shape(shape, axis):
class Sum(Function):
def forward(ctx, input, axis=None):
ctx.save_for_backward(input.shape)
return ll.reduce_op(ReduceOps.SUM, input, ll.Buffer(reduce_shape(input.shape, axis)))
return ctx.op.reduce_op(ReduceOps.SUM, input, ctx.op.Buffer(reduce_shape(input.shape, axis)))
def backward(ctx, grad_output):
shape_input, = ctx.saved_tensors
# NOTE: the b Buffer isn't used, since this is just for broadcast
ret = ll.Buffer(shape_input)
return ll.binary_op(BinaryOps.A, grad_output, ret, ret)
ret = ctx.op.Buffer(shape_input)
return ctx.op.binary_op(BinaryOps.A, grad_output, ret, ret)
class Max(Function):
def forward(ctx, input, axis=None):
ret = ll.reduce_op(ReduceOps.MAX, input, ll.Buffer(reduce_shape(input.shape, axis)))
ret = ctx.op.reduce_op(ReduceOps.MAX, input, ctx.op.Buffer(reduce_shape(input.shape, axis)))
ctx.save_for_backward(input, ret)
return ret
def backward(ctx, grad_output):
input, ret = ctx.saved_tensors
ret2 = ll.binary_op(BinaryOps.CMPEQ, input, ret, ll.Buffer(input.shape))
div = ll.reduce_op(ReduceOps.SUM, ret2, ll.Buffer(grad_output.shape))
ll.binary_op(BinaryOps.DIV, div, ret2, ret2)
return ll.binary_op(BinaryOps.MUL, ret2, grad_output, ret2)
ret2 = ctx.op.binary_op(BinaryOps.CMPEQ, input, ret, ctx.op.Buffer(input.shape))
div = ctx.op.reduce_op(ReduceOps.SUM, ret2, ctx.op.Buffer(grad_output.shape))
ctx.op.binary_op(BinaryOps.DIV, div, ret2, ret2)
return ctx.op.binary_op(BinaryOps.MUL, ret2, grad_output, ret2)
# ************* binary ops *************
def unbroadcast(out, in_sh):
return ll.reduce_op(ReduceOps.SUM, out, ll.Buffer(in_sh))
def unbroadcast(ctx, out, in_sh):
return ctx.op.reduce_op(ReduceOps.SUM, out, ctx.op.Buffer(in_sh))
class Add(Function):
def forward(ctx, x, y):
ctx.save_for_backward(x.shape, y.shape)
buf = ll.Buffer(binary_broadcast(x.shape, y.shape))
return ll.binary_op(BinaryOps.ADD, x, y, buf) #ll.Buffer(binary_broadcast(x.shape, y.shape)))
buf = ctx.op.Buffer(binary_broadcast(x.shape, y.shape))
return ctx.op.binary_op(BinaryOps.ADD, x, y, buf) #ctx.op.Buffer(binary_broadcast(x.shape, y.shape)))
def backward(ctx, grad_output):
shape_x, shape_y = ctx.saved_tensors
return unbroadcast(grad_output, shape_x) if ctx.needs_input_grad[0] else None, \
unbroadcast(grad_output, shape_y) if ctx.needs_input_grad[1] else None
return unbroadcast(ctx, grad_output, shape_x) if ctx.needs_input_grad[0] else None, \
unbroadcast(ctx, grad_output, shape_y) if ctx.needs_input_grad[1] else None
class Sub(Function):
def forward(ctx, x, y):
ctx.save_for_backward(x.shape, y.shape)
return ll.binary_op(BinaryOps.SUB, x, y, ll.Buffer(binary_broadcast(x.shape, y.shape)))
return ctx.op.binary_op(BinaryOps.SUB, x, y, ctx.op.Buffer(binary_broadcast(x.shape, y.shape)))
def backward(ctx, grad_output):
shape_x, shape_y = ctx.saved_tensors
neg_grad_output = ll.unary_op(UnaryOps.NEG, grad_output, ll.Buffer(grad_output.shape))
return unbroadcast(grad_output, shape_x) if ctx.needs_input_grad[0] else None, \
unbroadcast(neg_grad_output, shape_y) if ctx.needs_input_grad[1] else None
neg_grad_output = ctx.op.unary_op(UnaryOps.NEG, grad_output, ctx.op.Buffer(grad_output.shape))
return unbroadcast(ctx, grad_output, shape_x) if ctx.needs_input_grad[0] else None, \
unbroadcast(ctx, neg_grad_output, shape_y) if ctx.needs_input_grad[1] else None
class Mul(Function):
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
return ll.binary_op(BinaryOps.MUL, x, y, ll.Buffer(binary_broadcast(x.shape, y.shape)))
return ctx.op.binary_op(BinaryOps.MUL, x, y, ctx.op.Buffer(binary_broadcast(x.shape, y.shape)))
def backward(ctx, grad_output):
x,y = ctx.saved_tensors
tmp = ll.Buffer(grad_output.shape)
grad_x = unbroadcast(ll.binary_op(BinaryOps.MUL, y, grad_output, tmp), x.shape) if ctx.needs_input_grad[0] else None
grad_y = unbroadcast(ll.binary_op(BinaryOps.MUL, x, grad_output, tmp), y.shape) if ctx.needs_input_grad[1] else None
tmp = ctx.op.Buffer(grad_output.shape)
grad_x = unbroadcast(ctx, ctx.op.binary_op(BinaryOps.MUL, y, grad_output, tmp), x.shape) if ctx.needs_input_grad[0] else None
grad_y = unbroadcast(ctx, ctx.op.binary_op(BinaryOps.MUL, x, grad_output, tmp), y.shape) if ctx.needs_input_grad[1] else None
return grad_x, grad_y
class Pow(Function):
def forward(ctx, x, y):
ret = ll.Buffer(binary_broadcast(x.shape, y.shape))
ret = ctx.op.Buffer(binary_broadcast(x.shape, y.shape))
ctx.save_for_backward(x, y, ret)
return ll.binary_op(BinaryOps.POW, x, y, ret)
return ctx.op.binary_op(BinaryOps.POW, x, y, ret)
def backward(ctx, grad_output):
x,y,powxy = ctx.saved_tensors
tmp = ll.Buffer(grad_output.shape)
ll.binary_op(BinaryOps.DIV, x, powxy, tmp) # pow(x,y)/x
ll.binary_op(BinaryOps.MUL, y, tmp, tmp) # y * pow(x,y)/x
grad_x = unbroadcast(ll.binary_op(BinaryOps.MUL, grad_output, tmp, tmp), x.shape) if ctx.needs_input_grad[0] else None
log_x = ll.unary_op(UnaryOps.LOG, x, ll.Buffer(x.shape))
ll.binary_op(BinaryOps.MUL, log_x, powxy, tmp) # log(x) * pow(x,y)
grad_y = unbroadcast(ll.binary_op(BinaryOps.MUL, grad_output, tmp, tmp), y.shape) if ctx.needs_input_grad[1] else None
tmp = ctx.op.Buffer(grad_output.shape)
ctx.op.binary_op(BinaryOps.DIV, x, powxy, tmp) # pow(x,y)/x
ctx.op.binary_op(BinaryOps.MUL, y, tmp, tmp) # y * pow(x,y)/x
grad_x = unbroadcast(ctx, ctx.op.binary_op(BinaryOps.MUL, grad_output, tmp, tmp), x.shape) if ctx.needs_input_grad[0] else None
log_x = ctx.op.unary_op(UnaryOps.LOG, x, ctx.op.Buffer(x.shape))
ctx.op.binary_op(BinaryOps.MUL, log_x, powxy, tmp) # log(x) * pow(x,y)
grad_y = unbroadcast(ctx, ctx.op.binary_op(BinaryOps.MUL, grad_output, tmp, tmp), y.shape) if ctx.needs_input_grad[1] else None
return grad_x, grad_y
# ************* movement ops *************
@@ -133,48 +133,48 @@ class Reshape(Function):
def forward(ctx, x, shape):
ctx.save_for_backward(x.shape)
shape = tuple(-np.prod(x.shape) // np.prod(shape) if s == -1 else s for s in shape)
return ll.reshape(x, shape) # NOTE: this is not a copy
return ctx.op.reshape(x, shape) # NOTE: this is not a copy
def backward(ctx, grad_output):
in_shape, = ctx.saved_tensors
return ll.reshape(grad_output, in_shape)
return ctx.op.reshape(grad_output, in_shape)
class Transpose(Function):
def forward(ctx, x, order=(1,0)):
ctx.save_for_backward(order)
ret = ll.Buffer([x.shape[i] for i in order])
return ll.perm_axis(x, order, ret)
ret = ctx.op.Buffer([x.shape[i] for i in order])
return ctx.op.perm_axis(x, order, ret)
def backward(ctx, grad_output):
norder = np.argsort(ctx.order).tolist()
ret = ll.Buffer([grad_output.shape[i] for i in norder])
return ll.perm_axis(grad_output, norder, ret)
ret = ctx.op.Buffer([grad_output.shape[i] for i in norder])
return ctx.op.perm_axis(grad_output, norder, ret)
class Slice(Function):
def forward(ctx, x, arg=None):
ctx.save_for_backward(x.shape)
ret = ll.Buffer([y[1]-y[0] for y in arg])
return ll.inner_slice(x, arg, ret)
ret = ctx.op.Buffer([y[1]-y[0] for y in arg])
return ctx.op.inner_slice(x, arg, ret)
def backward(ctx, grad_output):
shape, = ctx.saved_tensors
narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(ctx.arg)]
ret = ll.Buffer([y[1]-y[0] for y in narg])
return ll.inner_slice(grad_output, narg, ret)
ret = ctx.op.Buffer([y[1]-y[0] for y in narg])
return ctx.op.inner_slice(grad_output, narg, ret)
# ************* processing ops *************
class Matmul(Function):
def forward(ctx, input, weight):
assert input.shape[-1] == weight.shape[-2]
ret = ll.Buffer(list(input.shape[0:-1])+[weight.shape[-1]])
ret = ctx.op.Buffer(list(input.shape[0:-1])+[weight.shape[-1]])
ctx.save_for_backward(input, weight)
return ll.matmul(input, weight, ret)
return ctx.op.matmul(input, weight, ret)
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
grad_input = ll.matmul(grad_output, weight, ll.Buffer(input.shape), transpose_b=True) if ctx.needs_input_grad[0] else None
grad_weight = ll.matmul(input, grad_output, ll.Buffer(weight.shape), transpose_a=True) if ctx.needs_input_grad[1] else None
grad_input = ctx.op.matmul(grad_output, weight, ctx.op.Buffer(input.shape), transpose_b=True) if ctx.needs_input_grad[0] else None
grad_weight = ctx.op.matmul(input, grad_output, ctx.op.Buffer(weight.shape), transpose_a=True) if ctx.needs_input_grad[1] else None
return grad_input, grad_weight
class Conv2D(Function):
@@ -192,7 +192,7 @@ class Conv2D(Function):
# output buffer
conv_args = H, W, ctx.groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs
return ll.conv(x, w, ll.Buffer((bs, cout, oy, ox)), conv_args)
return ctx.op.conv(x, w, ctx.op.Buffer((bs, cout, oy, ox)), conv_args)
def backward(ctx, grad_output):
bs,_,oy,ox = grad_output.shape
@@ -206,6 +206,6 @@ class Conv2D(Function):
rcout = cout//ctx.groups
conv_args = H, W, ctx.groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs
dx = ll.convdx(w, grad_output, ll.Buffer((bs, cin_, iy, ix)), conv_args) if ctx.needs_input_grad[0] else None
dw = ll.convdw(x, grad_output, ll.Buffer((cout, cin, H, W)), conv_args) if ctx.needs_input_grad[1] else None
dx = ctx.op.convdx(w, grad_output, ctx.op.Buffer((bs, cin_, iy, ix)), conv_args) if ctx.needs_input_grad[0] else None
dw = ctx.op.convdw(x, grad_output, ctx.op.Buffer((cout, cin, H, W)), conv_args) if ctx.needs_input_grad[1] else None
return dx, dw

View File

@@ -400,6 +400,7 @@ def register(name, fxn, device=Device.CPU):
x = [Tensor(np.array([arg], dtype=tt.dtype), device=tt.device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
f = Tensor.ops[tt.device][name]
f.device = tt.device
f.op = importlib.import_module(f".cpu", f"tinygrad.llops")
return f.apply(f, *x, **kwargs)
if getattr(Tensor, name, None) is not None:
setattr(Tensor, "_"+name, dispatch)