IMAGE == 1, add reshape to the ast

This commit is contained in:
George Hotz
2023-01-11 20:56:03 -08:00
parent 9ff6c532eb
commit 3ea38cac72
3 changed files with 73 additions and 14 deletions

View File

@@ -3,7 +3,7 @@ from tinygrad.ops import MovementOps, ProcessingOps
# input format is N, H x W, C//4 x 4
# dweight format is oc//4 x ch, cw x 4(oc)
# weight format is oc//4 x ch, ic//4, cw, 4(oc) x 4(ic)
def preprocessing_op(x,w,C):
def preprocessing_op(x,w,C,make_image=True):
w = w.movement_op(MovementOps.RESHAPE, (C.groups, C.rcout, C.cin, C.H, C.W))
#print(x.shape, w.shape)
@@ -67,7 +67,9 @@ def preprocessing_op(x,w,C):
bw = bw.op.src[0]
if bw.realized:
# weights are static
w.realize().image
wr = w.realize() #.image
if make_image:
wr.image
return x,w,C
def postprocessing_op(ret, C, C_initial):

View File

@@ -13,6 +13,7 @@ sys.tracebacklimit = 20
# Behaviour switches read from environment variables; each value is parsed as an int.
# OPT defaults to 1 and gates the four optimization toggles assigned below.
OPT = int(os.getenv("OPT", "1"))
# NOCONV defaults to 0; when non-zero, processing_op applies padding with an explicit slice of x.
NOCONV = int(os.getenv("NOCONV", "0"))
# IMAGE defaults to 0; IMAGE == 1 makes processing_op take its image-specific conv path.
IMAGE = int(os.getenv("IMAGE", "0"))
# TODO: movement ops that only change shape are really nops. treat them as such
# All four toggles share the same condition today: each is True exactly when OPT >= 1.
REMOVE_MOVEMENT_NOPS, MERGE_UNARY_OPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1
@@ -45,9 +46,14 @@ def _realize_loadops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer],
else:
raise NotImplementedError(f"unknown LoadOp {self.op.op}")
# TODO: these two are generic, replace them?
# Realize a movement-op LazyBuffer.
# Returns a 3-tuple: the resulting buffer, the list of realized source buffers, and the op type.
# NOTE(review): this text looks like a rendered diff (see the `@@` hunk headers) with indentation
# and +/- markers stripped, so pre-change and post-change lines appear together — confirm
# against the real file before relying on statement order.
def _realize_movementops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
# NOTE(review): presumably the pre-change line, superseded by `real_src = src.realize(...)` further down — confirm
real_src = self.op.src[0].realize(self.device)
# the single input of this movement op; fetched without calling realize() on it yet
src = self.op.src[0]
# fuse RESHAPE and ReduceOps
# taken only when the source is an unrealized ReduceOps with at most one child and this op is a
# RESHAPE; the reshape target (self.op.arg) is passed along as output_shape instead of being applied here
if src.realized is None and src.optype == ReduceOps and self.op.op == MovementOps.RESHAPE and len(src.children) <= 1:
# on this path the returned tuple is whatever _realize_reduceops_w_shape produces for src
return _realize_reduceops_w_shape(src, output_shape = self.op.arg)
# fallback: realize the source on this buffer's device, then apply the movement op to the realized buffer
real_src = src.realize(self.device)
return real_src.movement_op(self.op.op, self.op.arg), [real_src], MovementOps
def _realize_processingops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
@@ -55,18 +61,20 @@ def _realize_processingops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBu
return real_src_x.processing_op(self.op.op, real_src_w, self.op.arg), [real_src_x, real_src_w], ProcessingOps
# this supports late merging an upstream Elementwise op
# Realize a reduce-op LazyBuffer by building a LazyOp AST and running it via self.dbuffer.exec_ast.
# output_shape: when not None, the reduce AST is wrapped in a MovementOps.RESHAPE LazyOp with that
#   shape before execution; when None the reduce AST is executed as built.
# Returns a 3-tuple: the executed buffer, the list of realized source buffers, and ReduceOps.
# NOTE(review): this text looks like a rendered diff with indentation and +/- markers stripped.
# The first `def` line and the two `return` lines inside the if/else branches are presumably
# pre-change lines kept by the diff view — confirm against the real file.
def _realize_reduceops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
def _realize_reduceops_w_shape(self:LazyBuffer, output_shape=None) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
# TODO: this can also corealize a binary op after the reduce, not just before
src = self.op.src[0]
# merge path: requires the MERGE_ELEMENTWISE_INTO_REDUCE toggle and an unrealized BinaryOps
# source with at most one child
if MERGE_ELEMENTWISE_INTO_REDUCE and src.realized is None and src.optype == BinaryOps and len(src.children) <= 1:
# this is the new version, deprecate _processing_op
# realize every buffer returned by get_buffers(src.op), keyed by its LazyBuffer
real_srcs : Dict[LazyBuffer, DeviceBuffer] = {x:x.realize(self.device) for x in get_buffers(src.op)}
# the reduce takes the source's op tree directly as its input
# (realize_buffers presumably swaps in the realized buffers — confirm against its definition)
ast = LazyOp(self.op.op, (realize_buffers(real_srcs, src.op),), self.op.arg)
return self.dbuffer.exec_ast(ast), list(real_srcs.values()), ReduceOps
else:
# plain path: realize the source by itself and reduce over the realized buffer
real_src = src.realize(self.device)
# same dict shape as the merge path so the shared return below works for both branches
real_srcs = {src:real_src}
ast = LazyOp(self.op.op, (real_src,), self.op.arg)
return self.dbuffer.exec_ast(ast), [real_src], ReduceOps
# optional trailing reshape is expressed inside the AST rather than applied to the result afterwards
if output_shape is not None: ast = LazyOp(MovementOps.RESHAPE, (ast, ), output_shape)
return self.dbuffer.exec_ast(ast), list(real_srcs.values()), ReduceOps
# wrapper under the original name: same call with output_shape left at its default of None
def _realize_reduceops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]: return _realize_reduceops_w_shape(self)
# this supports late merging an upstream Reduce op and even an Elementwise op above that
def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
@@ -106,8 +114,9 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
for x in real_srcs.keys():
if real_srcs[x] is None:
real_srcs[x] = x.movement_op(MovementOps.RESHAPE, intermediate_shape).realize(self.device)
ret = self.dbuffer.exec_ast(realize_buffers(real_srcs, self.op))
return ret.movement_op(MovementOps.RESHAPE, self.shape), [x for x in real_srcs.values() if not isinstance(x, LazyOp) and x is not None], op_type
ast = LazyOp(MovementOps.RESHAPE, (realize_buffers(real_srcs, self.op), ), self.shape)
ret = self.dbuffer.exec_ast(ast)
return ret, [x for x in real_srcs.values() if not isinstance(x, LazyOp) and x is not None], op_type
_realize = {LoadOps:_realize_loadops, ReduceOps:_realize_reduceops, MovementOps:_realize_movementops, BinaryOps:_realize_binaryops, ProcessingOps:_realize_processingops}
@@ -240,6 +249,46 @@ class LazyBuffer:
def processing_op(self:LazyBuffer, op:ProcessingOps, w:LazyBuffer, C:ConvArgs) -> LazyBuffer:
x = self
if IMAGE == 1:
from accel.opencl.preprocessing import preprocessing_op, postprocessing_op # type: ignore
Cold = C
x,w,C = preprocessing_op(x, w, Cold, False)
# set up the conv
# (C.bs*C.iy, C.ix*C.groups*C.cin//4, 4)
x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.iy, C.ix, C.groups, C.cin))
# padding (implicit is fine in image)
x = x.slice(((0, x.shape[0]), (-C.py, x.shape[1]+C.py_), (-C.px, x.shape[2]+C.px_), (0, x.shape[3]), (0, x.shape[4])))
x = x.movement_op(MovementOps.STRIDED, (
(C.bs, x.shape[1]*x.shape[2]*C.groups*C.cin),
(C.oy, C.sy*x.shape[2]*C.groups*C.cin), (C.ox, C.sx*C.groups*C.cin),
(C.groups, C.cin), (1, 1), (1, 1),
(C.H, C.dy*x.shape[2]*C.groups*C.cin), (C.W, C.dx*C.groups*C.cin), (C.cin//4 if C.cin >= 4 else 1, 4), (4 if C.cin >= 4 else 1, 1)
))
x = x.movement_op(MovementOps.EXPAND, (C.bs, C.oy, C.ox, C.groups, C.rcout//4 if C.rcout >= 4 else 1, 4 if C.rcout >= 4 else 1, C.H, C.W, C.cin//4 if C.cin >= 4 else 1, 4 if C.cin >= 4 else 1))
x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.oy, C.ox, C.cout//4, 4, C.H, C.W, C.cin//4 if C.cin >= 4 else 1, 4 if C.cin >= 4 else 1))
# set up the weights
if C.cin == 1:
# depthwise
w = w.movement_op(MovementOps.RESHAPE, (C.cout//4, C.H, C.W, 4))
w = w.movement_op(MovementOps.PERMUTE, (0,3,1,2))
w = w.movement_op(MovementOps.RESHAPE, (1, 1, 1, C.cout//4, 4, C.H, C.W, 1, 1)) \
.movement_op(MovementOps.EXPAND, (C.bs, C.oy, C.ox, C.cout//4, 4, C.H, C.W, 1, 1))
else:
w = w.movement_op(MovementOps.RESHAPE, (C.cout//4, C.H, C.cin//4, C.W, 4, 4))
w = w.movement_op(MovementOps.PERMUTE, (0,4,1,3,2,5))
w = w.movement_op(MovementOps.RESHAPE, (1, 1, 1, C.cout//4, 4, C.H, C.W, C.cin//4, 4)) \
.movement_op(MovementOps.EXPAND, (C.bs, C.oy, C.ox, C.cout//4, 4, C.H, C.W, C.cin//4, 4))
# now do the conv in this space
ret = x.binary_op(BinaryOps.MUL, w).reduce_op(ReduceOps.SUM, (C.bs, C.oy, C.ox, C.cout//4, 4, 1, 1, 1, 1))
ret = ret.movement_op(MovementOps.RESHAPE, (C.bs*C.oy, C.ox*C.cout//4, 4)).contiguous() #True)
return postprocessing_op(ret, C, Cold)
# TODO: fixup C?
if NOCONV or not getattr(x.dbuffer, "SUPPORTS_PADDING", False):
x = x.slice(((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))

View File

@@ -98,6 +98,16 @@ def get_first_reduce(shapes):
# ast kernel can contain one ReduceOp with arbitrary Binary/Unary ops
class ASTKernel:
def __init__(self, ast:LazyOp):
# key for lookup in cache (can change, str might not be right)
self.key = str(ast)
# if the AST ends with a RESHAPE, we remove it and create the buffer accordingly
if ast.op == MovementOps.RESHAPE:
output_shape = ast.arg
ast = ast.src[0]
else:
output_shape = None
self.info = get_lazyop_info(ast)
self.bufs = dedup(get_buffers(ast))
reduceops = [x for x in get_lazyops(ast) if x.op in ReduceOps]
@@ -107,17 +117,15 @@ class ASTKernel:
self.ast = ast
# create the buffer we are returning (as the same type as the input buffers) and add it as the first buffer
self.ret = type(self.bufs[0])(self.info.shape)
self.bufs = [self.ret] + self.bufs
self.ret = type(self.bufs[0])(output_shape if output_shape else self.info.shape)
self.ret.cl # does the allocation
self.bufs = [type(self.ret)(self.info.shape, hostbuf=self.ret)] + self.bufs
# check valid AST kernel
assert all_same([x.shape for x in self.earlybufs]), "all earlybufs must have the same shape"
assert all_same([x.shape for x in self.bufs if x not in self.earlybufs]), "all latebufs must have the same shape"
assert all_same([len(x.shape) for x in self.bufs]), "all bufs must have the same shape size"
# key for lookup in cache (can change, str might not be right)
self.key = str(ast)
def process(self):
# get shape, strides, and offset
# if it's a multiview buffer we take the final view