IMAGE == 1, add reshape to the ast

This commit is contained in:
George Hotz
2023-01-11 20:56:03 -08:00
parent 9ff6c532eb
commit 3ea38cac72
3 changed files with 73 additions and 14 deletions

View File

@@ -3,7 +3,7 @@ from tinygrad.ops import MovementOps, ProcessingOps
# input format is N, H x W, C//4 x 4
# dweight format is oc//4 x ch, cw x 4(oc)
# weight format is oc//4 x ch, ic//4, cw, 4(oc) x 4(ic)
def preprocessing_op(x,w,C):
def preprocessing_op(x,w,C,make_image=True):
w = w.movement_op(MovementOps.RESHAPE, (C.groups, C.rcout, C.cin, C.H, C.W))
#print(x.shape, w.shape)
@@ -67,7 +67,9 @@ def preprocessing_op(x,w,C):
bw = bw.op.src[0]
if bw.realized:
# weights are static
w.realize().image
wr = w.realize() #.image
if make_image:
wr.image
return x,w,C
def postprocessing_op(ret, C, C_initial):