diff --git a/tinygrad/llops/cpu.py b/tinygrad/llops/cpu.py index f347cb863c..9d283d7021 100644 --- a/tinygrad/llops/cpu.py +++ b/tinygrad/llops/cpu.py @@ -4,6 +4,7 @@ from tinygrad.helpers import UnaryOps, BinaryOps, ReduceOps class CPUBuffer(np.ndarray): def __new__(cls, shape, hostbuf=None): if hostbuf is not None: + #print(shape, hostbuf.shape) return super().__new__(cls, shape, buffer=hostbuf.data) else: return super().__new__(cls, shape) @@ -90,6 +91,7 @@ def convdw(x,grad_output,dw,conv_args): tx = get_tx(x, conv_args) ggg = grad_output.reshape(bs,groups,rcout,oy,ox) gdw = dw.reshape((groups,rcout,cin,H,W)) + gdw[:] = 0 for g in range(groups): #'ikYX,ijYXyx -> kjyx' gdw[g] += np.tensordot(ggg[:,g], tx[:,g], ((0,2,3),(0,2,3))) @@ -97,10 +99,10 @@ def convdw(x,grad_output,dw,conv_args): def convdx(w,grad_output,dx,conv_args): H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args - OY,OX = dx.shape[2:4] ggg = grad_output.reshape(bs,groups,rcout,oy,ox) tw = w.reshape(groups, rcout, cin, H, W) - gdx = dx.reshape((bs,groups,cin,OY,OX)) + gdx = dx.reshape((bs,groups,cin,iy,ix)) + gdx[:] = 0 for k in range(oy*ox): Y, X = k//ox, k%ox iY,iX = Y*ys, X*xs diff --git a/tinygrad/ops/ops_gpu.py b/tinygrad/ops/ops_gpu.py index 8720eb3fd8..ea303c4e07 100644 --- a/tinygrad/ops/ops_gpu.py +++ b/tinygrad/ops/ops_gpu.py @@ -1,9 +1,16 @@ +import os import numpy as np from tinygrad.helpers import binary_broadcast, UnaryOps, BinaryOps, ReduceOps from ..tensor import Function -from ..llops.opencl import GPUBuffer as Buffer -from ..llops.opencl import unary_op, binary_op, reduce_op, perm_axis, inner_slice -from ..llops.opencl import matmul, conv, convdw, convdx + +if int(os.getenv("LLCPU", 0)) == 1: + from ..llops.cpu import CPUBuffer as Buffer + from ..llops.cpu import unary_op, binary_op, reduce_op, perm_axis, inner_slice + from ..llops.cpu import matmul, conv, convdw, convdx +else: + from ..llops.opencl import GPUBuffer as Buffer + from ..llops.opencl import unary_op, binary_op, reduce_op, perm_axis, inner_slice + from ..llops.opencl import matmul, conv, convdw, convdx # ************* unary ops *************