mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
IMAGE == 1, add reshape to the ast
This commit is contained in:
@@ -3,7 +3,7 @@ from tinygrad.ops import MovementOps, ProcessingOps
|
||||
# input format is N, H x W, C//4 x 4
|
||||
# dweight format is oc//4 x ch, cw x 4(oc)
|
||||
# weight format is oc//4 x ch, ic//4, cw, 4(oc) x 4(ic)
|
||||
def preprocessing_op(x,w,C):
|
||||
def preprocessing_op(x,w,C,make_image=True):
|
||||
w = w.movement_op(MovementOps.RESHAPE, (C.groups, C.rcout, C.cin, C.H, C.W))
|
||||
#print(x.shape, w.shape)
|
||||
|
||||
@@ -67,7 +67,9 @@ def preprocessing_op(x,w,C):
|
||||
bw = bw.op.src[0]
|
||||
if bw.realized:
|
||||
# weights are static
|
||||
w.realize().image
|
||||
wr = w.realize() #.image
|
||||
if make_image:
|
||||
wr.image
|
||||
return x,w,C
|
||||
|
||||
def postprocessing_op(ret, C, C_initial):
|
||||
|
||||
Reference in New Issue
Block a user