mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 13:28:06 -05:00
refactor out matmul
This commit is contained in:
@@ -188,4 +188,40 @@ def inner_slice(ctx, x, arg):
|
||||
buffer_np(ctx, np.array(x.shape, dtype=np.int32)),
|
||||
buffer_np(ctx, np.array(ret.shape, dtype=np.int32)),
|
||||
buffer_np(ctx, np.array(shift, dtype=np.int32)))
|
||||
return ret
|
||||
return ret
|
||||
|
||||
# c = a@b
|
||||
def matmul(a, b, c, transpose_a=False, transpose_b=False):
|
||||
cnt = np.prod(a.shape[0:-2]) if len(a.shape) > 2 else 1
|
||||
isize, msize, osize = i32(a.shape[-2]), i32(a.shape[-1]), i32(c.shape[-1])
|
||||
if transpose_a: isize,msize = msize,isize
|
||||
assert isize == c.shape[-2]
|
||||
assert (msize == b.shape[-1]) if transpose_b else (msize == b.shape[-2])
|
||||
assert (osize == b.shape[-2]) if transpose_b else (osize == b.shape[-1])
|
||||
|
||||
matmul_prg = clbuild("matmul", """
|
||||
__kernel void matmul(
|
||||
__global const float *input, __global const float *weight, __global float *res,
|
||||
int isize, int is0, int is1, int msize, int ws0, int ws1, int osize
|
||||
) {
|
||||
int stride = get_global_id(2);
|
||||
|
||||
int X = get_global_id(0); // isize
|
||||
int Y = get_global_id(1); // osize
|
||||
|
||||
float ret = 0.0;
|
||||
for (int x = 0; x < msize; x++) {
|
||||
ret += input[X * is0 + x * is1 + isize*msize*stride] *
|
||||
weight[Y * ws0 + x * ws1 + msize*osize*stride];
|
||||
}
|
||||
|
||||
res[X * osize + Y + isize*osize*stride] = ret;
|
||||
}""")
|
||||
|
||||
matmul_prg([isize, osize, cnt], None,
|
||||
a.cl, b.cl, c.cl,
|
||||
isize,
|
||||
msize if not transpose_a else i32(1), i32(1) if not transpose_a else isize,
|
||||
msize,
|
||||
i32(1) if not transpose_b else msize, osize if not transpose_b else i32(1),
|
||||
osize)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pyopencl as cl
|
||||
import numpy as np
|
||||
from ..tensor import Function
|
||||
from ..llops.gpu import GPUBuffer, clbuild, buffer_new, unary_op, binary_op, reduce_op, perm_axis, inner_slice
|
||||
from ..llops.gpu import GPUBuffer, clbuild, buffer_new, unary_op, binary_op, reduce_op, perm_axis, inner_slice, matmul
|
||||
|
||||
i32 = np.int32
|
||||
|
||||
@@ -141,53 +141,17 @@ class Slice(Function):
|
||||
class Matmul(Function):
|
||||
def forward(ctx, input, weight):
|
||||
assert input.shape[-1] == weight.shape[-2]
|
||||
cnt = np.prod(input.shape[0:-2]) if len(input.shape) > 2 else 1
|
||||
isize, msize, osize = i32(input.shape[-2]), i32(input.shape[-1]), i32(weight.shape[-1])
|
||||
ret = buffer_new(ctx, list(input.shape[0:-2])+[isize, osize])
|
||||
|
||||
matmul = clbuild("matmul", """
|
||||
__kernel void matmul(
|
||||
__global const float *input, __global const float *weight, __global float *res,
|
||||
int isize, int is0, int is1, int msize, int ws0, int ws1, int osize
|
||||
) {
|
||||
int stride = get_global_id(2);
|
||||
|
||||
int X = get_global_id(0); // isize
|
||||
int Y = get_global_id(1); // osize
|
||||
|
||||
float ret = 0.0;
|
||||
for (int x = 0; x < msize; x++) {
|
||||
ret += input[X * is0 + x * is1 + isize*msize*stride] *
|
||||
weight[Y * ws0 + x * ws1 + msize*osize*stride];
|
||||
}
|
||||
|
||||
res[X * osize + Y + isize*osize*stride] = ret;
|
||||
}""")
|
||||
ctx.save_for_backward(input, weight, matmul, cnt)
|
||||
|
||||
# (isize,msize) x (msize,osize) = (isize,osize)
|
||||
matmul([isize, osize, cnt], None,
|
||||
input.cl, weight.cl, ret.cl, isize,
|
||||
msize, i32(1), msize, i32(1), osize, osize)
|
||||
ret = buffer_new(ctx, list(input.shape[0:-1])+[weight.shape[-1]])
|
||||
ctx.save_for_backward(input, weight)
|
||||
matmul(input, weight, ret)
|
||||
return ret
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
input, weight, matmul, cnt = ctx.saved_tensors
|
||||
isize, msize, osize = i32(input.shape[-2]), i32(input.shape[-1]), i32(weight.shape[-1])
|
||||
|
||||
input, weight = ctx.saved_tensors
|
||||
grad_input = buffer_new(ctx, input.shape)
|
||||
grad_weight = buffer_new(ctx, weight.shape)
|
||||
|
||||
# (isize,osize) x (msize,osize) = (isize,msize)
|
||||
matmul([isize, msize, cnt], None,
|
||||
grad_output.cl, weight.cl, grad_input.cl, isize,
|
||||
osize, i32(1), osize, osize, i32(1), msize)
|
||||
|
||||
# (isize,msize) x (isize,osize) = (msize,osize)
|
||||
matmul([msize, osize, cnt], None,
|
||||
input.cl, grad_output.cl, grad_weight.cl, msize,
|
||||
i32(1), msize, isize, i32(1), osize, osize)
|
||||
|
||||
matmul(grad_output, weight, grad_input, transpose_b=True)
|
||||
matmul(input, grad_output, grad_weight, transpose_a=True)
|
||||
return grad_input, grad_weight
|
||||
|
||||
class Conv2D(Function):
|
||||
|
||||
Reference in New Issue
Block a user