mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
IMAGE == 1, add reshape to the ast
This commit is contained in:
@@ -3,7 +3,7 @@ from tinygrad.ops import MovementOps, ProcessingOps
|
||||
# input format is N, H x W, C//4 x 4
|
||||
# dweight format is oc//4 x ch, cw x 4(oc)
|
||||
# weight format is oc//4 x ch, ic//4, cw, 4(oc) x 4(ic)
|
||||
def preprocessing_op(x,w,C):
|
||||
def preprocessing_op(x,w,C,make_image=True):
|
||||
w = w.movement_op(MovementOps.RESHAPE, (C.groups, C.rcout, C.cin, C.H, C.W))
|
||||
#print(x.shape, w.shape)
|
||||
|
||||
@@ -67,7 +67,9 @@ def preprocessing_op(x,w,C):
|
||||
bw = bw.op.src[0]
|
||||
if bw.realized:
|
||||
# weights are static
|
||||
w.realize().image
|
||||
wr = w.realize() #.image
|
||||
if make_image:
|
||||
wr.image
|
||||
return x,w,C
|
||||
|
||||
def postprocessing_op(ret, C, C_initial):
|
||||
|
||||
@@ -13,6 +13,7 @@ sys.tracebacklimit = 20
|
||||
|
||||
# Integer switches read from the environment once, at import time.
# Unset variables fall back to the defaults given here.
OPT = int(os.environ.get("OPT", "1"))
NOCONV = int(os.environ.get("NOCONV", "0"))
IMAGE = int(os.environ.get("IMAGE", "0"))

# TODO: movement ops that only change shape are really nops. treat them as such
# All four optimisation flags are currently driven by the same threshold, OPT >= 1.
REMOVE_MOVEMENT_NOPS = MERGE_UNARY_OPS = MERGE_ELEMENTWISE_INTO_REDUCE = SHUFFLE_MOVEMENT_OPS = OPT >= 1
|
||||
@@ -45,9 +46,14 @@ def _realize_loadops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer],
|
||||
else:
|
||||
raise NotImplementedError(f"unknown LoadOp {self.op.op}")
|
||||
|
||||
# TODO: these two are generic, replace them?
|
||||
def _realize_movementops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
    """Realize a lazy movement op.

    Returns a 3-tuple of (resulting buffer, list of realized source buffers,
    op type). On the normal path the op type is ``MovementOps``; on the fused
    path the tuple is whatever ``_realize_reduceops_w_shape`` returns.
    """
    src = self.op.src[0]

    # fuse RESHAPE and ReduceOps
    # The source must NOT be realized before this test: realizing it first would
    # presumably populate ``src.realized`` and make the fusion unreachable.
    # The fusion is only taken when this op is the sole consumer of the reduce
    # (len(src.children) <= 1).
    if src.realized is None and src.optype == ReduceOps and self.op.op == MovementOps.RESHAPE and len(src.children) <= 1:
        return _realize_reduceops_w_shape(src, output_shape=self.op.arg)

    # Generic path: realize the source once, then apply the movement op to it.
    real_src = src.realize(self.device)
    return real_src.movement_op(self.op.op, self.op.arg), [real_src], MovementOps
|
||||
|
||||
def _realize_processingops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
|
||||
@@ -55,18 +61,20 @@ def _realize_processingops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBu
|
||||
return real_src_x.processing_op(self.op.op, real_src_w, self.op.arg), [real_src_x, real_src_w], ProcessingOps
|
||||
|
||||
# this supports late merging an upstream Elementwise op
|
||||
def _realize_reduceops_w_shape(self:LazyBuffer, output_shape=None) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
    """Realize a reduce op, optionally folding a trailing RESHAPE into its AST.

    Args:
        self: the lazy buffer whose op is the reduce to realize.
        output_shape: when not None, the reduce AST is wrapped in a
            ``MovementOps.RESHAPE`` node carrying this shape before execution.

    Returns:
        A 3-tuple of (buffer returned by ``exec_ast``, list of realized source
        buffers, ``ReduceOps``).
    """
    # TODO: this can also corealize a binary op after the reduce, not just before
    src = self.op.src[0]
    if MERGE_ELEMENTWISE_INTO_REDUCE and src.realized is None and src.optype == BinaryOps and len(src.children) <= 1:
        # this is the new version, deprecate _processing_op
        # Realize only the leaf buffers of the unrealized binary op and splice
        # its op tree directly under the reduce.
        real_srcs : Dict[LazyBuffer, DeviceBuffer] = {x:x.realize(self.device) for x in get_buffers(src.op)}
        ast = LazyOp(self.op.op, (realize_buffers(real_srcs, src.op),), self.op.arg)
    else:
        real_src = src.realize(self.device)
        real_srcs = {src:real_src}
        ast = LazyOp(self.op.op, (real_src,), self.op.arg)

    # Both branches fall through to here (no early return), so the optional
    # reshape wrap applies whichever branch built the AST.
    if output_shape is not None:
        ast = LazyOp(MovementOps.RESHAPE, (ast, ), output_shape)
    return self.dbuffer.exec_ast(ast), list(real_srcs.values()), ReduceOps
|
||||
def _realize_reduceops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
    """Realize a reduce op by delegating to ``_realize_reduceops_w_shape`` with its default arguments."""
    result = _realize_reduceops_w_shape(self)
    return result
|
||||
|
||||
# this supports late merging an upstream Reduce op and even an Elementwise op above that
|
||||
def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
|
||||
@@ -106,8 +114,9 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
|
||||
for x in real_srcs.keys():
|
||||
if real_srcs[x] is None:
|
||||
real_srcs[x] = x.movement_op(MovementOps.RESHAPE, intermediate_shape).realize(self.device)
|
||||
ret = self.dbuffer.exec_ast(realize_buffers(real_srcs, self.op))
|
||||
return ret.movement_op(MovementOps.RESHAPE, self.shape), [x for x in real_srcs.values() if not isinstance(x, LazyOp) and x is not None], op_type
|
||||
ast = LazyOp(MovementOps.RESHAPE, (realize_buffers(real_srcs, self.op), ), self.shape)
|
||||
ret = self.dbuffer.exec_ast(ast)
|
||||
return ret, [x for x in real_srcs.values() if not isinstance(x, LazyOp) and x is not None], op_type
|
||||
|
||||
# Dispatch table: op-type class -> function that realizes a LazyBuffer of that type.
_realize = {
    LoadOps: _realize_loadops,
    ReduceOps: _realize_reduceops,
    MovementOps: _realize_movementops,
    BinaryOps: _realize_binaryops,
    ProcessingOps: _realize_processingops,
}
|
||||
|
||||
@@ -240,6 +249,46 @@ class LazyBuffer:
|
||||
|
||||
def processing_op(self:LazyBuffer, op:ProcessingOps, w:LazyBuffer, C:ConvArgs) -> LazyBuffer:
|
||||
x = self
|
||||
|
||||
if IMAGE == 1:
|
||||
from accel.opencl.preprocessing import preprocessing_op, postprocessing_op # type: ignore
|
||||
Cold = C
|
||||
x,w,C = preprocessing_op(x, w, Cold, False)
|
||||
|
||||
# set up the conv
|
||||
# (C.bs*C.iy, C.ix*C.groups*C.cin//4, 4)
|
||||
x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.iy, C.ix, C.groups, C.cin))
|
||||
|
||||
# padding (implicit is fine in image)
|
||||
x = x.slice(((0, x.shape[0]), (-C.py, x.shape[1]+C.py_), (-C.px, x.shape[2]+C.px_), (0, x.shape[3]), (0, x.shape[4])))
|
||||
|
||||
x = x.movement_op(MovementOps.STRIDED, (
|
||||
(C.bs, x.shape[1]*x.shape[2]*C.groups*C.cin),
|
||||
(C.oy, C.sy*x.shape[2]*C.groups*C.cin), (C.ox, C.sx*C.groups*C.cin),
|
||||
(C.groups, C.cin), (1, 1), (1, 1),
|
||||
(C.H, C.dy*x.shape[2]*C.groups*C.cin), (C.W, C.dx*C.groups*C.cin), (C.cin//4 if C.cin >= 4 else 1, 4), (4 if C.cin >= 4 else 1, 1)
|
||||
))
|
||||
x = x.movement_op(MovementOps.EXPAND, (C.bs, C.oy, C.ox, C.groups, C.rcout//4 if C.rcout >= 4 else 1, 4 if C.rcout >= 4 else 1, C.H, C.W, C.cin//4 if C.cin >= 4 else 1, 4 if C.cin >= 4 else 1))
|
||||
x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.oy, C.ox, C.cout//4, 4, C.H, C.W, C.cin//4 if C.cin >= 4 else 1, 4 if C.cin >= 4 else 1))
|
||||
|
||||
# set up the weights
|
||||
if C.cin == 1:
|
||||
# depthwise
|
||||
w = w.movement_op(MovementOps.RESHAPE, (C.cout//4, C.H, C.W, 4))
|
||||
w = w.movement_op(MovementOps.PERMUTE, (0,3,1,2))
|
||||
w = w.movement_op(MovementOps.RESHAPE, (1, 1, 1, C.cout//4, 4, C.H, C.W, 1, 1)) \
|
||||
.movement_op(MovementOps.EXPAND, (C.bs, C.oy, C.ox, C.cout//4, 4, C.H, C.W, 1, 1))
|
||||
else:
|
||||
w = w.movement_op(MovementOps.RESHAPE, (C.cout//4, C.H, C.cin//4, C.W, 4, 4))
|
||||
w = w.movement_op(MovementOps.PERMUTE, (0,4,1,3,2,5))
|
||||
w = w.movement_op(MovementOps.RESHAPE, (1, 1, 1, C.cout//4, 4, C.H, C.W, C.cin//4, 4)) \
|
||||
.movement_op(MovementOps.EXPAND, (C.bs, C.oy, C.ox, C.cout//4, 4, C.H, C.W, C.cin//4, 4))
|
||||
|
||||
# now do the conv in this space
|
||||
ret = x.binary_op(BinaryOps.MUL, w).reduce_op(ReduceOps.SUM, (C.bs, C.oy, C.ox, C.cout//4, 4, 1, 1, 1, 1))
|
||||
ret = ret.movement_op(MovementOps.RESHAPE, (C.bs*C.oy, C.ox*C.cout//4, 4)).contiguous() #True)
|
||||
return postprocessing_op(ret, C, Cold)
|
||||
|
||||
# TODO: fixup C?
|
||||
if NOCONV or not getattr(x.dbuffer, "SUPPORTS_PADDING", False):
|
||||
x = x.slice(((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))
|
||||
|
||||
@@ -98,6 +98,16 @@ def get_first_reduce(shapes):
|
||||
# ast kernel can contain one ReduceOp with arbitrary Binary/Unary ops
|
||||
class ASTKernel:
|
||||
def __init__(self, ast:LazyOp):
|
||||
# key for lookup in cache (can change, str might not be right)
|
||||
self.key = str(ast)
|
||||
|
||||
# if the AST ends with a RESHAPE, we remove it and create the buffer accordingly
|
||||
if ast.op == MovementOps.RESHAPE:
|
||||
output_shape = ast.arg
|
||||
ast = ast.src[0]
|
||||
else:
|
||||
output_shape = None
|
||||
|
||||
self.info = get_lazyop_info(ast)
|
||||
self.bufs = dedup(get_buffers(ast))
|
||||
reduceops = [x for x in get_lazyops(ast) if x.op in ReduceOps]
|
||||
@@ -107,17 +117,15 @@ class ASTKernel:
|
||||
self.ast = ast
|
||||
|
||||
# create the buffer we are returning (as the same type as the input buffers) and add it as the first buffer
|
||||
self.ret = type(self.bufs[0])(self.info.shape)
|
||||
self.bufs = [self.ret] + self.bufs
|
||||
self.ret = type(self.bufs[0])(output_shape if output_shape else self.info.shape)
|
||||
self.ret.cl # does the allocation
|
||||
self.bufs = [type(self.ret)(self.info.shape, hostbuf=self.ret)] + self.bufs
|
||||
|
||||
# check valid AST kernel
|
||||
assert all_same([x.shape for x in self.earlybufs]), "all earlybufs must have the same shape"
|
||||
assert all_same([x.shape for x in self.bufs if x not in self.earlybufs]), "all latebufs must have the same shape"
|
||||
assert all_same([len(x.shape) for x in self.bufs]), "all bufs must have the same shape size"
|
||||
|
||||
# key for lookup in cache (can change, str might not be right)
|
||||
self.key = str(ast)
|
||||
|
||||
def process(self):
|
||||
# get shape, strides, and offset
|
||||
# if it's a multiview buffer we take the final view
|
||||
|
||||
Reference in New Issue
Block a user