llcpu convs work

This commit is contained in:
George Hotz
2022-06-08 10:51:09 -07:00
parent 6bdcf5ef59
commit 81d16d105e
2 changed files with 14 additions and 5 deletions

View File

@@ -4,6 +4,7 @@ from tinygrad.helpers import UnaryOps, BinaryOps, ReduceOps
class CPUBuffer(np.ndarray):
def __new__(cls, shape, hostbuf=None):
# Construct an instance of `cls` with the given `shape` through the parent
# class's __new__ (np.ndarray, per the class header).
#   shape   -- forwarded unchanged as the array shape.
#   hostbuf -- optional object whose `.data` attribute is handed to ndarray
#              as its `buffer` argument; None means no backing buffer is given.
# NOTE(review): no dtype is passed in either branch, so ndarray's default
# dtype applies -- confirm that hostbuf's element type and size match it.
if hostbuf is not None:
# Debug aid left disabled: compare the requested shape with hostbuf's shape.
#print(shape, hostbuf.shape)
# Use hostbuf.data as the storage for the new array instead of allocating.
return super().__new__(cls, shape, buffer=hostbuf.data)
else:
# No host buffer: ndarray allocates storage itself. Per NumPy's ndarray
# constructor docs the contents are not initialized, so callers must fill
# the array before reading from it.
return super().__new__(cls, shape)
@@ -90,6 +91,7 @@ def convdw(x,grad_output,dw,conv_args):
tx = get_tx(x, conv_args)
ggg = grad_output.reshape(bs,groups,rcout,oy,ox)
gdw = dw.reshape((groups,rcout,cin,H,W))
gdw[:] = 0
for g in range(groups):
#'ikYX,ijYXyx -> kjyx'
gdw[g] += np.tensordot(ggg[:,g], tx[:,g], ((0,2,3),(0,2,3)))
@@ -97,10 +99,10 @@ def convdw(x,grad_output,dw,conv_args):
def convdx(w,grad_output,dx,conv_args):
H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
OY,OX = dx.shape[2:4]
ggg = grad_output.reshape(bs,groups,rcout,oy,ox)
tw = w.reshape(groups, rcout, cin, H, W)
gdx = dx.reshape((bs,groups,cin,OY,OX))
gdx = dx.reshape((bs,groups,cin,iy,ix))
gdx[:] = 0
for k in range(oy*ox):
Y, X = k//ox, k%ox
iY,iX = Y*ys, X*xs

View File

@@ -1,9 +1,16 @@
import os
import numpy as np
from tinygrad.helpers import binary_broadcast, UnaryOps, BinaryOps, ReduceOps
from ..tensor import Function
from ..llops.opencl import GPUBuffer as Buffer
from ..llops.opencl import unary_op, binary_op, reduce_op, perm_axis, inner_slice
from ..llops.opencl import matmul, conv, convdw, convdx
if int(os.getenv("LLCPU", 0)) == 1:
from ..llops.cpu import CPUBuffer as Buffer
from ..llops.cpu import unary_op, binary_op, reduce_op, perm_axis, inner_slice
from ..llops.cpu import matmul, conv, convdw, convdx
else:
from ..llops.opencl import GPUBuffer as Buffer
from ..llops.opencl import unary_op, binary_op, reduce_op, perm_axis, inner_slice
from ..llops.opencl import matmul, conv, convdw, convdx
# ************* unary ops *************