From 365e62a609154ced94e5acb3becfff25dcdc9401 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Sun, 5 Jun 2022 14:33:06 -0700
Subject: [PATCH] refactor out matmul

---
 tinygrad/llops/gpu.py   | 38 ++++++++++++++++++++++++++++++-
 tinygrad/ops/ops_gpu.py | 50 ++++++-----------------------------------
 2 files changed, 44 insertions(+), 44 deletions(-)

diff --git a/tinygrad/llops/gpu.py b/tinygrad/llops/gpu.py
index 518dd1bb2d..44cd90eec3 100644
--- a/tinygrad/llops/gpu.py
+++ b/tinygrad/llops/gpu.py
@@ -188,4 +188,40 @@ def inner_slice(ctx, x, arg):
     buffer_np(ctx, np.array(x.shape, dtype=np.int32)),
     buffer_np(ctx, np.array(ret.shape, dtype=np.int32)),
     buffer_np(ctx, np.array(shift, dtype=np.int32)))
-  return ret
\ No newline at end of file
+  return ret
+
+# c = a@b
+def matmul(a, b, c, transpose_a=False, transpose_b=False):
+  cnt = np.prod(a.shape[0:-2]) if len(a.shape) > 2 else 1
+  isize, msize, osize = i32(a.shape[-2]), i32(a.shape[-1]), i32(c.shape[-1])
+  if transpose_a: isize,msize = msize,isize
+  assert isize == c.shape[-2]
+  assert (msize == b.shape[-1]) if transpose_b else (msize == b.shape[-2])
+  assert (osize == b.shape[-2]) if transpose_b else (osize == b.shape[-1])
+
+  matmul_prg = clbuild("matmul", """
+  __kernel void matmul(
+    __global const float *input, __global const float *weight, __global float *res,
+    int isize, int is0, int is1, int msize, int ws0, int ws1, int osize
+  ) {
+    int stride = get_global_id(2);
+
+    int X = get_global_id(0); // isize
+    int Y = get_global_id(1); // osize
+
+    float ret = 0.0;
+    for (int x = 0; x < msize; x++) {
+      ret += input[X * is0 + x * is1 + isize*msize*stride] *
+        weight[Y * ws0 + x * ws1 + msize*osize*stride];
+    }
+
+    res[X * osize + Y + isize*osize*stride] = ret;
+  }""")
+
+  matmul_prg([isize, osize, cnt], None,
+    a.cl, b.cl, c.cl,
+    isize,
+    msize if not transpose_a else i32(1), i32(1) if not transpose_a else isize,
+    msize,
+    i32(1) if not transpose_b else msize, osize if not transpose_b else i32(1),
+    osize)
diff --git a/tinygrad/ops/ops_gpu.py b/tinygrad/ops/ops_gpu.py
index bb64b6ff4b..0fe196b08a 100644
--- a/tinygrad/ops/ops_gpu.py
+++ b/tinygrad/ops/ops_gpu.py
@@ -1,7 +1,7 @@
 import pyopencl as cl
 import numpy as np
 from ..tensor import Function
-from ..llops.gpu import GPUBuffer, clbuild, buffer_new, unary_op, binary_op, reduce_op, perm_axis, inner_slice
+from ..llops.gpu import GPUBuffer, clbuild, buffer_new, unary_op, binary_op, reduce_op, perm_axis, inner_slice, matmul
 
 i32 = np.int32
 
@@ -141,53 +141,17 @@ class Slice(Function):
 class Matmul(Function):
   def forward(ctx, input, weight):
     assert input.shape[-1] == weight.shape[-2]
-    cnt = np.prod(input.shape[0:-2]) if len(input.shape) > 2 else 1
-    isize, msize, osize = i32(input.shape[-2]), i32(input.shape[-1]), i32(weight.shape[-1])
-    ret = buffer_new(ctx, list(input.shape[0:-2])+[isize, osize])
-
-    matmul = clbuild("matmul", """
-    __kernel void matmul(
-      __global const float *input, __global const float *weight, __global float *res,
-      int isize, int is0, int is1, int msize, int ws0, int ws1, int osize
-    ) {
-      int stride = get_global_id(2);
-
-      int X = get_global_id(0); // isize
-      int Y = get_global_id(1); // osize
-
-      float ret = 0.0;
-      for (int x = 0; x < msize; x++) {
-        ret += input[X * is0 + x * is1 + isize*msize*stride] *
-          weight[Y * ws0 + x * ws1 + msize*osize*stride];
-      }
-
-      res[X * osize + Y + isize*osize*stride] = ret;
-    }""")
-    ctx.save_for_backward(input, weight, matmul, cnt)
-
-    # (isize,msize) x (msize,osize) = (isize,osize)
-    matmul([isize, osize, cnt], None,
-      input.cl, weight.cl, ret.cl, isize,
-      msize, i32(1), msize, i32(1), osize, osize)
+    ret = buffer_new(ctx, list(input.shape[0:-1])+[weight.shape[-1]])
+    ctx.save_for_backward(input, weight)
+    matmul(input, weight, ret)
     return ret
 
   def backward(ctx, grad_output):
-    input, weight, matmul, cnt = ctx.saved_tensors
-    isize, msize, osize = i32(input.shape[-2]), i32(input.shape[-1]), i32(weight.shape[-1])
-
+    input, weight = ctx.saved_tensors
     grad_input = buffer_new(ctx, input.shape)
     grad_weight = buffer_new(ctx, weight.shape)
-
-    # (isize,osize) x (msize,osize) = (isize,msize)
-    matmul([isize, msize, cnt], None,
-      grad_output.cl, weight.cl, grad_input.cl, isize,
-      osize, i32(1), osize, osize, i32(1), msize)
-
-    # (isize,msize) x (isize,osize) = (msize,osize)
-    matmul([msize, osize, cnt], None,
-      input.cl, grad_output.cl, grad_weight.cl, msize,
-      i32(1), msize, isize, i32(1), osize, osize)
-
+    matmul(grad_output, weight, grad_input, transpose_b=True)
+    matmul(input, grad_output, grad_weight, transpose_a=True)
    return grad_input, grad_weight
 
 class Conv2D(Function):
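
Note on the stride trick in the new matmul helper: transpose_a and transpose_b never
materialize a transposed buffer; they only swap which stride pair (is0, is1) or
(ws0, ws1) is passed to the kernel, so the flat index X*is0 + x*is1 walks the same
storage either row-wise or column-wise. This is why backward can reuse the forward
kernel directly. The following is a minimal NumPy sketch of that index arithmetic
for the single-matrix case (cnt == 1), handy for checking the stride choices against
a @ b; the name ref_matmul and its plain Python loops are illustrative only and are
not part of this patch.

import numpy as np

def ref_matmul(a, b, transpose_a=False, transpose_b=False):
  # shapes as the host code derives them, with transposes applied logically
  isize, msize = (a.shape[1], a.shape[0]) if transpose_a else (a.shape[0], a.shape[1])
  osize = b.shape[0] if transpose_b else b.shape[1]
  # stride pairs exactly as gpu.py passes them to the kernel
  is0, is1 = (1, isize) if transpose_a else (msize, 1)
  ws0, ws1 = (msize, 1) if transpose_b else (1, osize)
  A, B = a.reshape(-1), b.reshape(-1)  # flat row-major storage, like the CL buffers
  c = np.zeros((isize, osize), dtype=np.float32)
  for X in range(isize):
    for Y in range(osize):
      c[X, Y] = sum(A[X*is0 + x*is1] * B[Y*ws0 + x*ws1] for x in range(msize))
  return c

a = np.random.randn(3, 4).astype(np.float32)
b = np.random.randn(4, 5).astype(np.float32)
assert np.allclose(ref_matmul(a, b), a @ b, atol=1e-4)
assert np.allclose(ref_matmul(a.T.copy(), b, transpose_a=True), a @ b, atol=1e-4)
assert np.allclose(ref_matmul(a, b.T.copy(), transpose_b=True), a @ b, atol=1e-4)

All three asserts compute the same product from differently laid-out storage, which
is exactly how backward gets grad_input and grad_weight without any extra copies.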