that's not where i thought we'd lose lines...

This commit is contained in:
George Hotz
2022-07-08 23:52:38 -07:00
parent 75e1848b09
commit c39a245696
2 changed files with 3 additions and 9 deletions

View File

@@ -49,12 +49,6 @@ class CPUBuffer(np.ndarray):
gx = x.ravel().reshape(C.bs,C.groups,C.cin,x.shape[2],x.shape[3])
tx = np.lib.stride_tricks.as_strided(gx,
shape=(C.bs, C.groups, C.cin, C.oy, C.ox, C.H, C.W),
strides=(*gx.strides[0:3], gx.strides[3]*C.sy, gx.strides[4]*C.sx, gx.strides[3]*C.dy, gx.strides[4]*C.dx),
writeable=False,
)
strides=(*gx.strides[0:3], gx.strides[3]*C.sy, gx.strides[4]*C.sx, gx.strides[3]*C.dy, gx.strides[4]*C.dx))
tw = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W)
tmp = np.empty((C.bs,C.groups,C.oy,C.ox,C.rcout),dtype=x.dtype)
for g in range(C.groups):
#ijYXyx,kjyx -> iYXk ->ikYX
tmp[:,g] = np.tensordot(tx[:,g], tw[g], ((1,4,5),(1,2,3)))
return np.moveaxis(tmp,4,2).reshape(C.bs, C.groups*C.rcout, C.oy, C.ox).view(CPUBuffer)
return np.einsum("nGChwHW, GkCHW -> nGkhw", tx, tw).reshape(C.bs, C.groups*C.rcout, C.oy, C.ox).view(CPUBuffer)

View File

@@ -80,7 +80,7 @@ def log_op(optype : OpType, op : List[Op], ret : DeviceBuffer, inp : List[Device
if 'label' not in G.nodes[nm(x)]: G.nodes[nm(x)]['label'] = str(x.shape)
if nm(ret) not in G.nodes: G.add_node(nm(ret))
if getattr(ret, "st", None) is not None and not ret.st.contiguous: # checked twice to make type checker happy
G.nodes[nm(ret)]['label'] = str(tuple(x[0] if x[1]!=0 else 0 for x in ret.st.views[-1].shape_strides))
G.nodes[nm(ret)]['label'] = str(ret.shape)+"\n"+str(tuple(x[0] if x[1]!=0 else 0 for x in ret.st.views[-1].shape_strides))
dashed = True
else:
G.nodes[nm(ret)]['label'] = str(ret.shape)