mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
llcpu convs work
This commit is contained in:
@@ -4,6 +4,7 @@ from tinygrad.helpers import UnaryOps, BinaryOps, ReduceOps
|
||||
class CPUBuffer(np.ndarray):
|
||||
def __new__(cls, shape, hostbuf=None):
|
||||
if hostbuf is not None:
|
||||
#print(shape, hostbuf.shape)
|
||||
return super().__new__(cls, shape, buffer=hostbuf.data)
|
||||
else:
|
||||
return super().__new__(cls, shape)
|
||||
@@ -90,6 +91,7 @@ def convdw(x,grad_output,dw,conv_args):
|
||||
tx = get_tx(x, conv_args)
|
||||
ggg = grad_output.reshape(bs,groups,rcout,oy,ox)
|
||||
gdw = dw.reshape((groups,rcout,cin,H,W))
|
||||
gdw[:] = 0
|
||||
for g in range(groups):
|
||||
#'ikYX,ijYXyx -> kjyx'
|
||||
gdw[g] += np.tensordot(ggg[:,g], tx[:,g], ((0,2,3),(0,2,3)))
|
||||
@@ -97,10 +99,10 @@ def convdw(x,grad_output,dw,conv_args):
|
||||
|
||||
def convdx(w,grad_output,dx,conv_args):
|
||||
H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
|
||||
OY,OX = dx.shape[2:4]
|
||||
ggg = grad_output.reshape(bs,groups,rcout,oy,ox)
|
||||
tw = w.reshape(groups, rcout, cin, H, W)
|
||||
gdx = dx.reshape((bs,groups,cin,OY,OX))
|
||||
gdx = dx.reshape((bs,groups,cin,iy,ix))
|
||||
gdx[:] = 0
|
||||
for k in range(oy*ox):
|
||||
Y, X = k//ox, k%ox
|
||||
iY,iX = Y*ys, X*xs
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
import os
|
||||
import numpy as np
|
||||
from tinygrad.helpers import binary_broadcast, UnaryOps, BinaryOps, ReduceOps
|
||||
from ..tensor import Function
|
||||
from ..llops.opencl import GPUBuffer as Buffer
|
||||
from ..llops.opencl import unary_op, binary_op, reduce_op, perm_axis, inner_slice
|
||||
from ..llops.opencl import matmul, conv, convdw, convdx
|
||||
|
||||
if int(os.getenv("LLCPU", 0)) == 1:
|
||||
from ..llops.cpu import CPUBuffer as Buffer
|
||||
from ..llops.cpu import unary_op, binary_op, reduce_op, perm_axis, inner_slice
|
||||
from ..llops.cpu import matmul, conv, convdw, convdx
|
||||
else:
|
||||
from ..llops.opencl import GPUBuffer as Buffer
|
||||
from ..llops.opencl import unary_op, binary_op, reduce_op, perm_axis, inner_slice
|
||||
from ..llops.opencl import matmul, conv, convdw, convdx
|
||||
|
||||
# ************* unary ops *************
|
||||
|
||||
|
||||
Reference in New Issue
Block a user